[Core] Reduce TTFT with concurrent partial prefills (#10235)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Authored by Joe Runde on 2025-02-14 16:36:07 -07:00
Committed by GitHub
parent 5e5c8e091e
commit 3bcb8c75da
6 changed files with 699 additions and 106 deletions
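
For orientation before the diff: the feature is exposed through three new EngineArgs fields added below (max_num_partial_prefills, max_long_partial_prefills, long_prefill_token_threshold). A minimal sketch of wiring them up, using only names that appear in this commit; the values are illustrative, not tuned recommendations:

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams

engine_args = EngineArgs(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,       # required for max_num_partial_prefills > 1
    max_num_partial_prefills=2,        # prefill up to 2 prompts concurrently
    max_long_partial_prefills=1,       # at most 1 "long" prompt among them
    long_prefill_token_threshold=512,  # prompts longer than this count as "long"
)
engine = LLMEngine.from_engine_args(engine_args)
engine.add_request("0", "hello" * 40, SamplingParams(temperature=0))
while engine.has_unfinished_requests():
    engine.step()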

View File

@ -8,7 +8,6 @@ prefill requests are chunked.
Run `pytest tests/models/test_chunked_prefill.py`.
"""
import os
from contextlib import nullcontext
import pytest
@ -233,7 +232,6 @@ def test_with_prefix_caching(
max_num_batched_tokens = max_num_seqs = chunk_size
outputs = {} # type: ignore
check_result = True
for enable in (True, False):
with vllm_runner(
model,
@ -245,25 +243,17 @@ def test_with_prefix_caching(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
) as vllm_model:
# It should fail when prefix caching is enable and chunk
# size is not a multiple of block size (16).
should_fail = chunk_size % 16 != 0 and enable
check_result &= not should_fail
outputs[enable] = []
# Send the request one-by-one to ensure the cache is populated.
with pytest.raises(ValueError) if should_fail else nullcontext():
for prompt in full_prompts:
outputs[enable] += vllm_model.generate_greedy([prompt],
max_tokens)
for prompt in full_prompts:
outputs[enable] += vllm_model.generate_greedy([prompt],
max_tokens)
# Check results only if we did not expect a failure.
if check_result:
check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)
check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)
@pytest.mark.parametrize("model", ["facebook/opt-125m"])

View File

@ -7,6 +7,9 @@ import pytest # noqa
from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, SequenceGroup
from .utils import create_dummy_prompt
@ -16,7 +19,7 @@ def get_sequence_groups(scheduler_output):
return [s.seq_group for s in scheduler_output.scheduled_seq_groups]
def append_new_token(seq_group, token_id: int):
def append_new_token(seq_group: SequenceGroup, token_id: int):
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
@ -123,6 +126,232 @@ def test_chunk():
assert out.num_batched_tokens == 57
def test_concurrent_chunking():
"""Verify prefills are chunked properly when
--max-num-partial-prefills is > 1"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 32
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []
# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
block_size=block_size)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
# Verify both requests are chunked with half of max_num_batched_tokens each
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 32
assert seq_group_meta[1].token_chunk_size == 32
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64
# After one iteration, both should have 60 - 32 = 28 tokens left to prefill
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 28
assert seq_group_meta[1].token_chunk_size == 28
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 56
def test_concurrent_chunking_large_requests():
"""Verify large prefill requests are run one at a time"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
cache_config.num_gpu_blocks = 3200
scheduler = Scheduler(scheduler_config, cache_config, None)
# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i),
prompt_length=1200, # Very large prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
# Verify only a single request is chunked, and it gets all 64 tokens
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(get_sequence_groups(out)) == 1
assert seq_group_meta[0].token_chunk_size == 64
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64
def test_short_prompts_jump_long_prompts_in_queue():
"""Verify large prefill requests are punted behind smaller ones if
another large prefill request is already running"""
block_size = 4
max_seqs = 60
max_model_len = 2000
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2, # Up to 2 partial prefills at a time
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 3200 # large KV cache size for large requests
cache_config.num_gpu_blocks = 3200
scheduler = Scheduler(scheduler_config, cache_config, None)
long_seqs: List[SequenceGroup] = []
short_seqs: List[SequenceGroup] = []
# Add 2 large seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i),
prompt_length=1200, # Very large prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
long_seqs.append(seq_group)
assert seq_group.is_prefill()
# Add 2 small seq groups behind them
for i in range(2):
_, seq_group = create_dummy_prompt(
str(i + 2),
prompt_length=40, # Very small prompt
block_size=block_size)
scheduler.add_seq_group(seq_group)
short_seqs.append(seq_group)
assert seq_group.is_prefill()
# Verify that one large and one small request are chunked
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 32 # large req gets 32 tokens
assert seq_group_meta[1].token_chunk_size == 32 # small req gets 32 tokens
# all 4 are prefilling
assert long_seqs[0].is_prefill()
assert long_seqs[1].is_prefill()
assert short_seqs[0].is_prefill()
assert short_seqs[1].is_prefill()
# First short and first long sequences have been scheduled
assert long_seqs[0].first_seq.get_num_computed_tokens() == 32
assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
assert short_seqs[0].first_seq.get_num_computed_tokens() == 32
assert short_seqs[1].first_seq.get_num_computed_tokens() == 0
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64
# In the second iteration, the first small request has only 8
# prefill tokens left, so it finishes prefill and moves to decode.
# The second small request is scheduled alongside it.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# the new small req got 64 - (32+8) tokens
assert seq_group_meta[0].token_chunk_size == 24
assert seq_group_meta[1].token_chunk_size == 32 # large req still got 32
# the other small request had only 8 tokens left
assert seq_group_meta[2].token_chunk_size == 8 # 40-32
# The first small request got to decode now
assert long_seqs[0].is_prefill()
assert long_seqs[1].is_prefill()
assert not short_seqs[0].is_prefill()
assert short_seqs[1].is_prefill()
# Both small requests have started in front of the second long request
assert long_seqs[0].first_seq.get_num_computed_tokens() == 64
assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
assert short_seqs[0].first_seq.get_num_computed_tokens() == 40
assert short_seqs[1].first_seq.get_num_computed_tokens() == 24
assert out.num_prefill_groups == 3
assert out.num_batched_tokens == 64
# the first small seq group has a new token appended.
append_new_token(short_seqs[0], 1)
# in the third iteration,
# the first small request is already decoding
# the second small request only has 16 tokens left and will enter decoding
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert seq_group_meta[0].token_chunk_size == 32 # large still got 32
# small req finished prefilling 40-24=16 tokens
assert seq_group_meta[1].token_chunk_size == 16
assert seq_group_meta[2].token_chunk_size == 1 # decode
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 49 # (32+16+1 decode)
# both small requests have now reached decode
assert long_seqs[0].is_prefill()
assert long_seqs[1].is_prefill()
assert not short_seqs[0].is_prefill()
assert not short_seqs[1].is_prefill()
assert long_seqs[0].first_seq.get_num_computed_tokens() == 96
assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
assert short_seqs[0].first_seq.get_num_computed_tokens() == 41
assert short_seqs[1].first_seq.get_num_computed_tokens() == 40
# both the small seq groups have a new token appended
append_new_token(short_seqs[0], 1)
append_new_token(short_seqs[1], 1)
# in the fourth iteration, both small requests are decoding
# so large request gets all the budget
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# large req gets 62 tokens (minus 2 for decode)
assert seq_group_meta[0].token_chunk_size == 62
assert seq_group_meta[1].token_chunk_size == 1 # decode
assert seq_group_meta[2].token_chunk_size == 1 # decode
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 64
assert long_seqs[0].first_seq.get_num_computed_tokens() == 158
# assert long_seqs[0].is_prefill()
# assert long_seqs[1].is_prefill()
# assert not short_seqs[0].is_prefill()
# assert not short_seqs[1].is_prefill()
# # both the small seq groups have a new token appended
# append_new_token(short_seqs[0], 1)
# append_new_token(short_seqs[1], 1)
# # in the fifth iteration, large request gets all the budget
# # while both small requests are decoding
# seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
# assert seq_group_meta[0].token_chunk_size == 62
# assert seq_group_meta[1].token_chunk_size == 1 # decode
# assert seq_group_meta[2].token_chunk_size == 1 # decode
# assert out.num_prefill_groups == 1
# assert out.num_batched_tokens == 64
def test_complex():
block_size = 4
max_seqs = 60
@ -508,7 +737,7 @@ def test_chunked_prefill_max_seqs():
assert not running[1].is_prefill()
def test_perfix_caching():
def test_prefix_caching():
"""Verify allocating full blocks when prefix caching is enabled."""
block_size = 4
max_seqs = 10
@ -548,3 +777,86 @@ def test_perfix_caching():
assert seq_group_meta[1].token_chunk_size == 12
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 62
def test_prefix_caching_with_concurrent_partial_prefills():
"""Verify allocating full blocks when prefix caching is enabled with
--max-num-partial-prefills > 1."""
block_size = 4
max_seqs = 10
max_model_len = 8000
max_num_batched_tokens = 60 # With two slots, each slot will get 30 tokens
scheduler_config = SchedulerConfig("generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True,
max_num_partial_prefills=2)
cache_config = CacheConfig(block_size,
1.0,
1,
"auto",
enable_prefix_caching=True)
cache_config.num_cpu_blocks = 0
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []
# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
block_size=block_size,
prompt_length=50)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# To partially prefill both sequences, both can chunk up to 30 tokens
# But the next lowest multiple of the block size (4) is 28
assert seq_group_meta[0].token_chunk_size == 28
assert seq_group_meta[1].token_chunk_size == 28
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 56
# On the next iteration, both sequences should finish prefill
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# Both sequences have 50 - 28 = 22 tokens left to prefill.
# This is not a multiple of the block size, but we don't care since we don't
# cache the final partial block of prefix sequences
assert seq_group_meta[0].token_chunk_size == 22
assert seq_group_meta[1].token_chunk_size == 22
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 44
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
def test_chunked_prefill_with_actual_engine(model: str,
max_num_partial_prefills: int):
"""Make sure the model can actually sample with concurrent
partial prefills
"""
prompt = "hello" * 40
engine_args = EngineArgs(
model=model,
max_num_partial_prefills=max_num_partial_prefills,
max_num_batched_tokens=40,
max_num_seqs=8,
enable_chunked_prefill=True,
gpu_memory_utilization=0.8,
)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(temperature=0)
for req_num in range(max_num_partial_prefills):
engine.add_request(f"{req_num}", prompt, sampling_params)
# first step
request_outputs = engine.step()
# zero outputs means all requests are still prefilling
assert len(request_outputs) == 0
assert len(engine.scheduler[0].running) == max_num_partial_prefills

View File

@ -1430,6 +1430,17 @@ class SchedulerConfig:
# Maximum length of a sequence (including prompt and generated text).
max_model_len: int = 8192
# Maximum number of sequences that can be partially prefilled concurrently
max_num_partial_prefills: int = 1
# Maximum number of "very long prompt" sequences that can be prefilled
# concurrently (long is defined by long_prefill_threshold)
max_long_partial_prefills: int = 1
# Prompt length (in tokens) above which a sequence is considered
# "long"; 0 means the default is derived from max_model_len
long_prefill_token_threshold: int = 0
# The number of slots to allocate per sequence per
# step, beyond the known token ids. This is used in speculative
# decoding to store KV activations of tokens which may or may not be
@ -1537,6 +1548,18 @@ class SchedulerConfig:
self.max_num_batched_tokens)
self.chunked_prefill_enabled = self.enable_chunked_prefill
if self.max_num_partial_prefills > 1:
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.max_model_len *
0.04)
logger.info(
"Concurrent partial prefills enabled with "
"max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
"long_prefill_token_threshold=%d",
self.max_num_partial_prefills, self.max_long_partial_prefills,
self.long_prefill_token_threshold)
self._verify_args()
def _verify_args(self) -> None:
@ -1568,6 +1591,29 @@ class SchedulerConfig:
f"({self.num_scheduler_steps}) must be greater than or "
"equal to 1.")
if self.max_num_partial_prefills < 1:
raise ValueError(
f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
"must be greater than or equal to 1.")
elif self.max_num_partial_prefills > 1:
if not self.chunked_prefill_enabled:
raise ValueError("Chunked prefill must be enabled to set "
"max_num_partial_prefills > 1.")
if self.long_prefill_token_threshold > self.max_model_len:
raise ValueError(
"long_prefill_token_threshold "
f"({self.long_prefill_token_threshold}) cannot be greater "
f"than the max_model_len ({self.max_model_len}).")
if (self.max_long_partial_prefills
< 1) or (self.max_long_partial_prefills
> self.max_num_partial_prefills):
raise ValueError(
f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
"must be greater than or equal to 1 and less than or equal to "
f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
@property
def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1
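
As a concrete illustration of the defaulting rule in the __post_init__ hunk above (a standalone sketch; the numbers are illustrative):

max_model_len = 8192
max_num_partial_prefills = 2
long_prefill_token_threshold = 0  # unset by the user

if max_num_partial_prefills > 1 and long_prefill_token_threshold == 0:
    # 4% of the model's context length, per the logic above
    long_prefill_token_threshold = int(max_model_len * 0.04)

print(long_prefill_token_threshold)  # 327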

View File

@ -17,7 +17,7 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
SequenceStage, SequenceStatus)
from vllm.utils import Device, PyObjectCache
logger = init_logger(__name__)
@ -39,6 +39,7 @@ class PreemptionMode(enum.Enum):
recompute them when the sequences are resumed, treating the sequences as
new prompts.
"""
SWAP = enum.auto()
RECOMPUTE = enum.auto()
@ -54,6 +55,7 @@ class SchedulingBudget:
happen if we only have chunked prefill scheduling, we can remove this
feature from the API when chunked prefill is enabled by default.
"""
token_budget: int
max_num_seqs: int
_request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
@ -132,6 +134,7 @@ class ScheduledSequenceGroup:
@dataclass
class SchedulerOutputs:
"""The scheduling decision made from a scheduler."""
# Scheduled sequence groups.
scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
# Number of prefill groups scheduled.
@ -205,6 +208,7 @@ class SchedulerRunningOutputs:
Could contain prefill (prefill that's chunked) or decodes. If there's not
enough memory, it can be preempted (for recompute) or swapped out.
"""
# Selected sequences that are running and in a decoding phase.
decode_seq_groups: List[ScheduledSequenceGroup]
# Selected sequences that are running and in a prefill phase.
@ -246,6 +250,7 @@ class SchedulerSwappedInOutputs:
Could contain prefill (prefill that's chunked) or decodes.
"""
# Selected sequences that are going to be swapped in and is in a
# decoding phase.
decode_seq_groups: List[ScheduledSequenceGroup]
@ -280,6 +285,7 @@ class SchedulerPrefillOutputs:
Could contain a fresh prefill requests or preempted requests that need
to be recomputed from scratch.
"""
# Selected sequences for prefill.
seq_groups: List[ScheduledSequenceGroup]
# Ignored sequence groups.
@ -321,6 +327,100 @@ def scheduled_seq_group_builder():
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
@dataclass
class PartialPrefillMetadata:
"""Holds information about the partial prefills that are currently running
during a single iteration of the Scheduler.
When chunked prefill is enabled, we allow a certain number of seqs to be
partially prefilled during each iteration. Having multiple partial prefills
in flight allows us to minimize TTFT and avoid decode starvation in cases
where a single sequence group with a very large prompt blocks the queue for
too many iterations.
The number of long prefill requests is limited so that smaller
requests may jump the queue in front of them and get to the decode
phase faster.
"""
# A lower bound on the total number of prefills to be scheduled during
# this iteration
schedulable_prefills: int
# The number of long prefill requests currently running
long_prefills: int
scheduler_config: SchedulerConfig
def can_schedule(self, seq_group: SequenceGroup) -> bool:
"""When concurrent partial prefills are enabled,
we limit the number of long requests and only accept
shorter requests from the queue while running them
concurrently"""
return not (seq_group.first_seq.get_num_new_tokens()
> self.scheduler_config.long_prefill_token_threshold
and self.long_prefills
>= self.scheduler_config.max_long_partial_prefills
and self.scheduler_config.max_num_partial_prefills > 1)
def maybe_increment_partial_prefills(self,
seq_group: SequenceGroup) -> None:
# When a new prefill is scheduled, we need to know if it is a
# long request
if (seq_group.first_seq.get_num_new_tokens()
> self.scheduler_config.long_prefill_token_threshold):
self.long_prefills += 1
@classmethod
def from_queues(
cls,
running: Deque[SequenceGroup],
waiting: Deque[SequenceGroup],
scheduler_config: SchedulerConfig,
) -> "PartialPrefillMetadata":
"""Create a PartialPrefillMetadata object from the current state of
the scheduler's queues.
This accounts for the currently running prefill requests, and peeks into
the waiting queue to see if there are more prefills to potentially be
scheduled during this iteration."""
prefills = 0
long_prefills = 0
waiting_long_prefills = 0
for sg in running:
if sg.first_seq.data.stage == SequenceStage.PREFILL:
prefills += 1
if (sg.first_seq.get_num_new_tokens()
> scheduler_config.long_prefill_token_threshold):
long_prefills += 1
for sg in waiting:
# Don't bother looping through the rest of the queue if we know
# there are already at
# least max_partial_prefills requests to fill
if prefills >= scheduler_config.max_num_partial_prefills:
break
# Don't count long requests from the waiting queue if we aren't
# going to schedule them anyway
if (sg.first_seq.get_num_new_tokens()
> scheduler_config.long_prefill_token_threshold):
if (long_prefills + waiting_long_prefills
>= scheduler_config.max_long_partial_prefills):
continue
waiting_long_prefills += 1
prefills += 1
# NB: long_prefills and waiting_long_prefills are tracked separately.
# We don't account for the waiting requests here because we need to use
# this metadata to track how many have actually been scheduled.
return PartialPrefillMetadata(
schedulable_prefills=min(
prefills, scheduler_config.max_num_partial_prefills),
long_prefills=long_prefills,
scheduler_config=scheduler_config,
)
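
To make the admission rule concrete: a standalone sketch of can_schedule() with the config collapsed to plain ints. The values mirror the scheduler tests above (max_model_len=2000 yields a default threshold of int(2000 * 0.04) = 80):

long_prefill_token_threshold = 80  # int(2000 * 0.04)
max_long_partial_prefills = 1
max_num_partial_prefills = 2

def can_schedule(num_new_tokens: int, long_prefills_running: int) -> bool:
    # Reject only when the request is long AND the long-prefill quota
    # is already full AND concurrent partial prefills are enabled.
    return not (num_new_tokens > long_prefill_token_threshold
                and long_prefills_running >= max_long_partial_prefills
                and max_num_partial_prefills > 1)

assert can_schedule(40, long_prefills_running=1)        # short: always admitted
assert can_schedule(1200, long_prefills_running=0)      # first long request: admitted
assert not can_schedule(1200, long_prefills_running=1)  # second long request: deferred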
class Scheduler:
def __init__(
@ -360,7 +460,8 @@ class Scheduler:
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
enable_caching=self.cache_config.enable_prefix_caching,
)
# Sequence groups in the WAITING state.
# Contain new prefill or preempted requests.
@ -421,6 +522,18 @@ class Scheduler:
# for processing and deallocation by the free_finished_seq_groups()
self._async_stopped: List[SequenceGroup] = []
# List with the chunk sizes to hand out to each sequence depending
# on how many partial prefills are running. This is slightly faster than
# running an integer division every time a prefill is scheduled.
# This splits the budget evenly among all prefills.
self.partial_prefill_budget_lookup_list = [0] * (
self.scheduler_config.max_num_partial_prefills + 1)
self.partial_prefill_budget_lookup_list[0] = (
scheduler_config.max_num_batched_tokens)
for i in range(1, self.scheduler_config.max_num_partial_prefills + 1):
self.partial_prefill_budget_lookup_list[i] = (
scheduler_config.max_num_batched_tokens // i)
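
For the configuration used in the tests above (max_num_batched_tokens=64, max_num_partial_prefills=2), the lookup list works out as follows; a small sketch:

max_num_batched_tokens = 64
max_num_partial_prefills = 2

lookup = [0] * (max_num_partial_prefills + 1)
lookup[0] = max_num_batched_tokens
for i in range(1, max_num_partial_prefills + 1):
    lookup[i] = max_num_batched_tokens // i

print(lookup)  # [64, 64, 32]: with 2 concurrent prefills, each slot gets 32 tokens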
@property
def next_cache_id(self):
return (self.cache_id + 1) % self.num_cache_iters
@ -500,8 +613,8 @@ class Scheduler:
self.block_manager.free_cross(seq_group)
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
return (len(self.waiting) != 0 or len(self.running) != 0
or len(self.swapped) != 0)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
@ -523,6 +636,7 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
@ -537,12 +651,14 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns:
SchedulerRunningOutputs.
"""
ret: SchedulerRunningOutputs = \
self._scheduler_running_outputs_cache[self.cache_id].get_object()
ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[
self.cache_id].get_object()
ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear()
@ -577,10 +693,14 @@ class Scheduler:
# 2. If a sequence is running with non-chunked prefill, then
# there it's a decoding sequence, and the cached tokens info is
# irrelevant.
num_uncached_new_tokens, _ = (
num_uncached_new_tokens, _ = \
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.RUNNING, enable_chunking,
budget))
seq_group,
SequenceStatus.RUNNING,
enable_chunking,
budget,
partial_prefill_metadata,
)
num_running_tokens = num_uncached_new_tokens
if num_running_tokens == 0:
@ -593,8 +713,8 @@ class Scheduler:
# to process the final tokens. The check below avoids this extra
# decode run when the model max len is reached, in order to avoid
# a memory overflow.
if self.use_async_output_proc and seq_group.seqs[0].get_len(
) > self.scheduler_config.max_model_len:
if (self.use_async_output_proc and seq_group.seqs[0].get_len()
> self.scheduler_config.max_model_len):
self._async_stopped.append(seq_group)
continue
@ -653,8 +773,9 @@ class Scheduler:
self._append_slots(seq_group, blocks_to_copy, enable_chunking)
is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \
self._scheduled_seq_group_cache[self.cache_id].get_object()
scheduled_seq_group: ScheduledSequenceGroup = (
self._scheduled_seq_group_cache[
self.cache_id].get_object())
scheduled_seq_group.seq_group = seq_group
if is_prefill:
scheduled_seq_group.token_chunk_size = num_running_tokens
@ -731,7 +852,8 @@ class Scheduler:
logger.warning(
"Failing the request %s because there's not enough kv "
"cache blocks to run the entire sequence.",
seq_group.request_id)
seq_group.request_id,
)
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
infeasible_seq_groups.append(seq_group)
@ -800,16 +922,17 @@ class Scheduler:
)
def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
if self.scheduler_config.chunked_prefill_enabled and \
not self.scheduler_config.is_multi_step:
if (self.scheduler_config.chunked_prefill_enabled
and not self.scheduler_config.is_multi_step):
prompt_limit = self.scheduler_config.max_model_len
else:
prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
prompt_limit = min(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens,
)
# Model is fine tuned with long context. Return the fine tuned max_len.
if (seq_group.lora_request
and seq_group.lora_request.long_lora_max_len):
if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
assert prompt_limit <= seq_group.lora_request.long_lora_max_len
return seq_group.lora_request.long_lora_max_len
else:
@ -817,7 +940,7 @@ class Scheduler:
def _get_priority(self,
seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
""" Get the priority of the sequence group.
"""Get the priority of the sequence group.
Highest preference to user-defined priority, followed by arrival time.
Args:
seq_group: The sequence group input.
@ -850,14 +973,14 @@ class Scheduler:
if waiting_queue:
seq_group = waiting_queue.popleft()
num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens_uncached, _ = (
num_new_tokens_uncached, _ = \
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, False, budget))
seq_group, SequenceStatus.WAITING, False, budget)
#Only preempt if priority inversion exists
# Only preempt if priority inversion exists
while running_queue and self._get_priority(
running_queue[-1]) > self._get_priority(seq_group):
#Only preempt if waiting sequence cannot be allocated
# Only preempt if waiting sequence cannot be allocated
can_allocate = self.block_manager.can_allocate(seq_group)
if (num_new_tokens_uncached > 0
and can_allocate == AllocStatus.OK
@ -867,7 +990,7 @@ class Scheduler:
)):
break
#Adjust budget to remove the victim sequence group
# Adjust budget to remove the victim sequence group
vseq_group = running_queue.pop()
num_running_tokens_uncached, _ = (
self._get_num_new_uncached_and_cached_tokens(
@ -878,11 +1001,11 @@ class Scheduler:
budget.subtract_num_seqs(vseq_group.request_id,
num_running_seqs)
#Preempt out the victim sequence group
# Preempt out the victim sequence group
self._preempt(vseq_group, blocks_to_swap_out)
waiting_queue.appendleft(vseq_group)
force_preemption_count += 1
#Put the sequence back into the waiting queue
# Put the sequence back into the waiting queue
waiting_queue.appendleft(seq_group)
waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
@ -896,6 +1019,7 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> SchedulerPrefillOutputs:
"""Schedule sequence groups that are in prefill stage.
@ -916,10 +1040,20 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns:
SchedulerPrefillOutputs.
"""
if budget.remaining_token_budget() == 0:
# Do nothing: Can't add any more prefill anyway
return SchedulerPrefillOutputs(
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking),
)
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = []
@ -933,10 +1067,19 @@ class Scheduler:
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
if (partial_prefill_metadata is not None
and not partial_prefill_metadata.can_schedule(seq_group)):
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
num_new_tokens_uncached, num_new_tokens_cached = (
self._get_num_new_uncached_and_cached_tokens(
seq_group, SequenceStatus.WAITING, enable_chunking,
budget))
seq_group,
SequenceStatus.WAITING,
enable_chunking,
budget,
partial_prefill_metadata=partial_prefill_metadata,
))
num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached
if not enable_chunking:
@ -947,7 +1090,10 @@ class Scheduler:
if num_new_tokens > prompt_limit:
logger.warning(
"Input prompt (%d tokens) is too long"
" and exceeds limit of %d", num_new_tokens, prompt_limit)
" and exceeds limit of %d",
num_new_tokens,
prompt_limit,
)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
@ -968,7 +1114,9 @@ class Scheduler:
logger.warning(
"Input prompt (%d tokens) + lookahead slots (%d) is "
"too long and exceeds the capacity of block_manager",
num_new_tokens, num_lookahead_slots)
num_new_tokens,
num_lookahead_slots,
)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
@ -1009,6 +1157,10 @@ class Scheduler:
waiting_queue.popleft()
self._allocate_and_set_running(seq_group)
if partial_prefill_metadata is not None:
partial_prefill_metadata.maybe_increment_partial_prefills(
seq_group)
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
# init_multi_step_from_lookahead_slots happens in append_slots
@ -1024,7 +1176,8 @@ class Scheduler:
num_scheduler_steps=self.scheduler_config.
num_scheduler_steps,
is_multi_step=self.scheduler_config.is_multi_step,
enable_chunking=enable_chunking)
enable_chunking=enable_chunking,
)
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
@ -1045,11 +1198,12 @@ class Scheduler:
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
is_prefill=True, enable_chunking=enable_chunking),
)
def _schedule_default(self) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize throughput. First,
it batches as many prefill requests as possible. Then it schedules
decodes. If there is pressure on GPU memory, decode requests can
@ -1065,9 +1219,9 @@ class Scheduler:
for seq_group in self.running:
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
curr_loras = set(
curr_loras = (set(
seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None
if seq_group.lora_int_id > 0) if self.lora_enabled else None)
prefills = SchedulerPrefillOutputs.create_empty()
running_scheduled = SchedulerRunningOutputs.create_empty()
@ -1093,9 +1247,10 @@ class Scheduler:
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras)
if (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out) == 0):
swapped_in = \
self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
@ -1115,8 +1270,8 @@ class Scheduler:
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out))
preempted = len(running_scheduled.preempted) + len(
running_scheduled.swapped_out)
# There should be no prefill from running queue because this policy
# doesn't allow chunked prefills.
@ -1154,7 +1309,7 @@ class Scheduler:
def _schedule_chunked_prefill(self) -> SchedulerOutputs:
"""Schedule queued requests.
Chunked prefill allows prefill requests to be chunked and batched
together with decode requests. This policy 1. schedules as many
decode requests as possible, 2. schedules chunked prefill requests
that are not
@ -1175,10 +1330,20 @@ class Scheduler:
prefills = SchedulerPrefillOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# Create partial prefill metadata
partial_prefill_metadata = PartialPrefillMetadata.from_queues(
running=self.running,
waiting=self.waiting,
scheduler_config=self.scheduler_config,
)
# Decoding should be always scheduled first by fcfs.
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=True)
running_scheduled = self._schedule_running(
budget,
curr_loras,
enable_chunking=True,
partial_prefill_metadata=partial_prefill_metadata,
)
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
@ -1186,9 +1351,12 @@ class Scheduler:
running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras)
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=True)
prefills = self._schedule_prefills(
budget,
curr_loras,
enable_chunking=True,
partial_prefill_metadata=partial_prefill_metadata,
)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
@ -1207,8 +1375,15 @@ class Scheduler:
[s.seq_group for s in swapped_in.prefill_seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
# Because multiple prefills may be running concurrently, we need to
# make sure that prefills which are scheduled to finish are listed
# before those that won't. This is so that on the next scheduling
# iteration when they have transitioned to the decode stage, they are
# properly prioritized over sequences that are still in the prefill
# stage.
self.running.extend(
[s.seq_group for s in running_scheduled.prefill_seq_groups])
self._order_finishing_prefills_first(
running_scheduled.prefill_seq_groups))
self.running.extend([s.seq_group for s in prefills.seq_groups])
# Update swapped requests.
@ -1225,7 +1400,7 @@ class Scheduler:
# If all prompts, then we set num_lookahead_slots to 0
# this allows us to go through the `no_spec` path in
# `spec_decode_worker.py`
all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
all_prefills = len(scheduled_seq_groups) == num_prefill_groups
num_lookahead_slots = (0 if
(all_prefills
and not self.scheduler_config.is_multi_step)
@ -1247,6 +1422,21 @@ class Scheduler:
len(running_scheduled.swapped_out)),
)
def _order_finishing_prefills_first(
self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
) -> List[SequenceGroup]:
"""Returns a list of prefilling SequenceGroups where sequences that are
scheduled to finish prefilling are listed first"""
finishing = [
s.seq_group for s in scheduled_prefill_seqs
if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size
]
not_finishing = [
s.seq_group for s in scheduled_prefill_seqs
if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size
]
return finishing + not_finishing
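
A behavior sketch of this ordering (the Scheduled tuple below is a hypothetical stand-in for ScheduledSequenceGroup, not vLLM API):

from typing import List, NamedTuple

class Scheduled(NamedTuple):
    name: str
    uncomputed: int  # prompt tokens left to prefill before this step
    chunk: int       # tokens scheduled for this step

def order_finishing_first(groups: List[Scheduled]) -> List[str]:
    finishing = [g.name for g in groups if g.uncomputed == g.chunk]
    not_finishing = [g.name for g in groups if g.uncomputed != g.chunk]
    return finishing + not_finishing

# The short request's final 8 tokens fit in this step's chunk, so it is
# listed first and gets decode priority on the next iteration.
print(order_finishing_first([
    Scheduled("long", uncomputed=1168, chunk=32),
    Scheduled("short", uncomputed=8, chunk=8),
]))  # ['short', 'long']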
def _schedule(self) -> SchedulerOutputs:
"""Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled:
@ -1385,10 +1575,12 @@ class Scheduler:
# between engine and worker.
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups > 0 else None,
multi_modal_placeholders=seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None,
multi_modal_data=(seq_group.multi_modal_data
if scheduler_outputs.num_prefill_groups
> 0 else None),
multi_modal_placeholders=(
seq_group.multi_modal_placeholders
if scheduler_outputs.num_prefill_groups > 0 else None),
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
)
@ -1494,10 +1686,12 @@ class Scheduler:
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING
def _append_slots(self,
seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],
enable_chunking: bool = False) -> None:
def _append_slots(
self,
seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],
enable_chunking: bool = False,
) -> None:
"""Appends new slots to the sequences in the given sequence group.
Args:
@ -1518,7 +1712,8 @@ class Scheduler:
num_lookahead_slots,
num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
is_multi_step=self.scheduler_config.is_multi_step,
enable_chunking=enable_chunking)
enable_chunking=enable_chunking,
)
seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
if self.scheduler_config.is_multi_step and enable_chunking:
@ -1561,8 +1756,11 @@ class Scheduler:
"not enough KV cache space. This can affect the end-to-end "
"performance. Increase gpu_memory_utilization or "
"tensor_parallel_size to provide more KV cache memory. "
"total_num_cumulative_preemption=%d", seq_group.request_id,
preemption_mode, self.num_cumulative_preemption + 1)
"total_num_cumulative_preemption=%d",
seq_group.request_id,
preemption_mode,
self.num_cumulative_preemption + 1,
)
self.num_cumulative_preemption += 1
if preemption_mode == PreemptionMode.RECOMPUTE:
@ -1668,6 +1866,7 @@ class Scheduler:
status: SequenceStatus,
enable_chunking: bool,
budget: SchedulingBudget,
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> Tuple[int, int]:
"""
Returns the number of new uncached and cached tokens to schedule for a
@ -1691,6 +1890,8 @@ class Scheduler:
to schedule.
enable_chunking: Whether to chunk the number of tokens to compute.
budget: The budget to chunk the number of tokens to compute.
partial_prefill_metadata: information about the partial prefills
that are currently running
Returns:
@ -1768,6 +1969,8 @@ class Scheduler:
budget,
self._get_prompt_limit(seq_group),
num_uncached_new_tokens,
self.partial_prefill_budget_lookup_list,
partial_prefill_metadata,
)
return num_uncached_new_tokens, num_cached_new_tokens
@ -1779,6 +1982,8 @@ class Scheduler:
budget: SchedulingBudget,
prompt_limit: int,
num_new_tokens: int,
partial_prefill_budget_lookup_list: List[int],
partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
) -> int:
"""
Chunks the number of new tokens to schedule based on the budget when
@ -1811,29 +2016,31 @@ class Scheduler:
# the sequence.
return num_new_tokens
return (0 if num_new_tokens > remaining_token_budget else
num_new_tokens)
return 0 if num_new_tokens > \
remaining_token_budget else num_new_tokens
# Get the number of tokens to allocate to this prefill slot
prefill_slot_budget = (
remaining_token_budget if partial_prefill_metadata is None else
partial_prefill_budget_lookup_list[
partial_prefill_metadata.schedulable_prefills])
if cache_config.enable_prefix_caching:
# Adjust the remaining token budget to be divisible by the block
# size when prefix caching is enabled.
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block
# size to avoid partial block matching.
# When prefix caching is enabled and we're partially prefilling
# a sequence, we always allocate a number of new tokens that is
# divisible by the block size to avoid partial block matching.
block_size = cache_config.block_size
remainder = budget.token_budget % block_size
if remainder != 0:
raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {remainder}")
# Round down to block size.
remaining_token_budget = (remaining_token_budget // block_size *
block_size)
# Don't exceed either the total budget or slot budget.
# Take min of those and get the next lowest multiple of the
# block size:
remaining_token_budget = (
min(remaining_token_budget, prefill_slot_budget) //
block_size) * block_size
# NB: In the case where num_new_tokens < budget, we are
# finishing prefill for this sequence, so we do not need to
# allocate a full block.
num_new_tokens = min(num_new_tokens, remaining_token_budget)
num_new_tokens = min(num_new_tokens, remaining_token_budget,
prefill_slot_budget)
return num_new_tokens
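
Worked through with the numbers from test_prefix_caching_with_concurrent_partial_prefills above (budget 60, two prefill slots, block size 4, 50-token prompts); a sketch:

remaining_token_budget = 60
prefill_slot_budget = 60 // 2  # two schedulable prefills -> 30 tokens per slot
block_size = 4
num_new_tokens = 50            # uncomputed prompt tokens for this sequence

# Take the min of the global and per-slot budgets, then round down to a
# multiple of the block size so prefix caching never matches a partial block.
rounded = (min(remaining_token_budget, prefill_slot_budget)
           // block_size) * block_size
num_new_tokens = min(num_new_tokens, rounded, prefill_slot_budget)
print(num_new_tokens)  # 28, as asserted in the test above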

View File

@ -120,6 +120,9 @@ class EngineArgs:
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_partial_prefills: Optional[int] = 1
max_long_partial_prefills: Optional[int] = 1
long_prefill_token_threshold: Optional[int] = 0
max_num_seqs: Optional[int] = None
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False
@ -515,6 +518,31 @@ class EngineArgs:
default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per '
'iteration.')
parser.add_argument(
"--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, the max number of concurrent \
partial prefills."
"Defaults to 1",
)
parser.add_argument(
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help="For chunked prefill, the maximum number of prompts longer "
"than --long-prefill-token-threshold that will be prefilled "
"concurrently. Setting this less than --max-num-partial-prefills "
"will allow shorter prompts to jump the queue in front of longer "
"prompts in some cases, improving latency. Defaults to 1.")
parser.add_argument(
"--long-prefill-token-threshold",
type=int,
default=EngineArgs.long_prefill_token_threshold,
help="For chunked prefill, a request is considered long if the "
"prompt is longer than this number of tokens. Defaults to 4%% of "
"the model's context length.",
)
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
@ -1244,7 +1272,11 @@ class EngineArgs:
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
policy=self.scheduling_policy)
policy=self.scheduling_policy,
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
)
lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank,

View File

@ -958,7 +958,9 @@ def get_logprobs(
if len(query_indices) == 0:
empty_sampled_logprob: SampleLogprobs = []
empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob]
num_seq_groups = len(sampling_metadata.seq_groups)
return [empty_prompt_logprob
] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups
selected_logprobs, ranks = None, None
top_logprobs, top_token_ids = None, None
@ -1225,6 +1227,10 @@ def _build_sampler_output(
assert sample_logprobs is not None
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
assert len(sampling_metadata.seq_groups) \
== len(maybe_deferred_sample_results) \
== len(prompt_logprobs) \
== len(sample_logprobs)
deferred_sample_results_args = None
for (seq_group, sample_result, group_prompt_logprobs,