Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Core] Reduce TTFT with concurrent partial prefills (#10235)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
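The new behavior is driven by three engine arguments added in this change (max_num_partial_prefills, max_long_partial_prefills, long_prefill_token_threshold), which only take effect together with chunked prefill. As a rough usage sketch distilled from the new test_chunked_prefill_with_actual_engine test below (the model name and values are the ones the tests use; the threshold value and the driver loop are illustrative, not part of the commit):

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams

engine_args = EngineArgs(
    model="facebook/opt-125m",        # model used by the tests in this PR
    enable_chunked_prefill=True,      # required for concurrent partial prefills
    max_num_batched_tokens=64,        # per-step token budget (the chunk size)
    max_num_partial_prefills=2,       # prefill up to 2 sequences concurrently
    max_long_partial_prefills=1,      # at most one "long" prompt among them
    long_prefill_token_threshold=80,  # prompts above this length count as long
)
engine = LLMEngine.from_engine_args(engine_args)
engine.add_request("0", "hello" * 40, SamplingParams(temperature=0))
while engine.has_unfinished_requests():
    engine.step()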
@@ -8,7 +8,6 @@ prefill requests are chunked.
Run `pytest tests/models/test_chunked_prefill.py`.
"""
import os
from contextlib import nullcontext

import pytest

@@ -233,7 +232,6 @@ def test_with_prefix_caching(

    max_num_batched_tokens = max_num_seqs = chunk_size
    outputs = {}  # type: ignore
    check_result = True
    for enable in (True, False):
        with vllm_runner(
                model,
@@ -245,25 +243,17 @@ def test_with_prefix_caching(
                enforce_eager=enforce_eager,
                max_num_seqs=max_num_seqs,
        ) as vllm_model:
            # It should fail when prefix caching is enable and chunk
            # size is not a multiple of block size (16).
            should_fail = chunk_size % 16 != 0 and enable
            check_result &= not should_fail
            outputs[enable] = []
            # Send the request one-by-one to ensure the cache is populated.
            with pytest.raises(ValueError) if should_fail else nullcontext():
                for prompt in full_prompts:
                    outputs[enable] += vllm_model.generate_greedy([prompt],
                                                                  max_tokens)
            for prompt in full_prompts:
                outputs[enable] += vllm_model.generate_greedy([prompt],
                                                              max_tokens)

    # Check results only if we did not expect a failure.
    if check_result:
        check_outputs_equal(
            outputs_0_lst=outputs[False],
            outputs_1_lst=outputs[True],
            name_0="w/o prefix caching",
            name_1="with prefix caching",
        )
    check_outputs_equal(
        outputs_0_lst=outputs[False],
        outputs_1_lst=outputs[True],
        name_0="w/o prefix caching",
        name_1="with prefix caching",
    )


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@@ -7,6 +7,9 @@ import pytest  # noqa

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, SequenceGroup

from .utils import create_dummy_prompt
@@ -16,7 +19,7 @@ def get_sequence_groups(scheduler_output):
    return [s.seq_group for s in scheduler_output.scheduled_seq_groups]


def append_new_token(seq_group, token_id: int):
def append_new_token(seq_group: SequenceGroup, token_id: int):
    for seq in seq_group.get_seqs():
        seq.append_token_id(token_id, {token_id: Logprob(token_id)})

@@ -123,6 +126,232 @@ def test_chunk():
    assert out.num_batched_tokens == 57


def test_concurrent_chunking():
    """Verify prefills are chunked properly when
    --max-num-partial-prefills is > 1"""
    block_size = 4
    max_seqs = 60
    max_model_len = 2000
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 32
    cache_config.num_gpu_blocks = 32
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Verify both requests are chunked with half of max_num_batched_tokens each
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 32
    assert seq_group_meta[1].token_chunk_size == 32
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64

    # After one iteration, both should have 60 - 32 = 28 tokens left to prefill
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 28
    assert seq_group_meta[1].token_chunk_size == 28
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 56

def test_concurrent_chunking_large_requests():
    """Verify large prefill requests are run one at a time"""
    block_size = 4
    max_seqs = 60
    max_model_len = 2000
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
    cache_config.num_gpu_blocks = 3200
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(
            str(i),
            prompt_length=1200,  # Very large prompt
            block_size=block_size)
        scheduler.add_seq_group(seq_group)

    # Verify only a single request is chunked, and it gets all 64 tokens
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == 64
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 64

def test_short_prompts_jump_long_prompts_in_queue():
    """Verify large prefill requests are punted behind smaller ones if
    another large prefill request is already running"""
    block_size = 4
    max_seqs = 60
    max_model_len = 2000
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
    cache_config.num_gpu_blocks = 3200
    scheduler = Scheduler(scheduler_config, cache_config, None)
    long_seqs: List[SequenceGroup] = []
    short_seqs: List[SequenceGroup] = []

    # Add 2 large seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(
            str(i),
            prompt_length=1200,  # Very large prompt
            block_size=block_size)
        scheduler.add_seq_group(seq_group)
        long_seqs.append(seq_group)
        assert seq_group.is_prefill()

    # Add 2 small seq groups behind them
    for i in range(2):
        _, seq_group = create_dummy_prompt(
            str(i + 2),
            prompt_length=40,  # Very small prompt
            block_size=block_size)
        scheduler.add_seq_group(seq_group)
        short_seqs.append(seq_group)
        assert seq_group.is_prefill()

    # Verify one large req and 1 small req chunked
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 32  # large req gets 32 tokens
    assert seq_group_meta[1].token_chunk_size == 32  # small req gets 32 tokens

    # all 4 are prefilling
    assert long_seqs[0].is_prefill()
    assert long_seqs[1].is_prefill()
    assert short_seqs[0].is_prefill()
    assert short_seqs[1].is_prefill()
    # First short and first long sequences have been scheduled
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 32
    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
    assert short_seqs[0].first_seq.get_num_computed_tokens() == 32
    assert short_seqs[1].first_seq.get_num_computed_tokens() == 0

    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64

    # in the second iteration,
    # the first small request had only 8 tokens left
    # so it went to decode
    # The other small req is scheduled
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    # the new small req got 64 - (32+8) tokens
    assert seq_group_meta[0].token_chunk_size == 24
    assert seq_group_meta[1].token_chunk_size == 32  # large req still got 32
    # the other small request had only 8 tokens left
    assert seq_group_meta[2].token_chunk_size == 8  # 40-32

    # The first small request got to decode now
    assert long_seqs[0].is_prefill()
    assert long_seqs[1].is_prefill()
    assert not short_seqs[0].is_prefill()
    assert short_seqs[1].is_prefill()
    # Both small requests have started in front of the second long request
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 64
    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
    assert short_seqs[0].first_seq.get_num_computed_tokens() == 40
    assert short_seqs[1].first_seq.get_num_computed_tokens() == 24

    assert out.num_prefill_groups == 3
    assert out.num_batched_tokens == 64
    # the first small seq group has a new token appended.
    append_new_token(short_seqs[0], 1)

    # in the third iteration,
    # the first small request is already decoding
    # the second small request only has 16 tokens left and will enter decoding
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 32  # large still got 32
    # small req finished prefilling 40-24=16 tokens
    assert seq_group_meta[1].token_chunk_size == 16
    assert seq_group_meta[2].token_chunk_size == 1  # decode
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 49  # (32+16+1 decode)

    # both small requests have now reached decode
    assert long_seqs[0].is_prefill()
    assert long_seqs[1].is_prefill()
    assert not short_seqs[0].is_prefill()
    assert not short_seqs[1].is_prefill()
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 96
    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
    assert short_seqs[0].first_seq.get_num_computed_tokens() == 41
    assert short_seqs[1].first_seq.get_num_computed_tokens() == 40

    # both the small seq groups have a new token appended
    append_new_token(short_seqs[0], 1)
    append_new_token(short_seqs[1], 1)

    # in the fourth iteration, both small requests are decoding
    # so large request gets all the budget
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)

    # large req gets 62 tokens (minus 2 for decode)
    assert seq_group_meta[0].token_chunk_size == 62
    assert seq_group_meta[1].token_chunk_size == 1  # decode
    assert seq_group_meta[2].token_chunk_size == 1  # decode
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 64

    assert long_seqs[0].first_seq.get_num_computed_tokens() == 158

    # assert long_seqs[0].is_prefill()
    # assert long_seqs[1].is_prefill()
    # assert not short_seqs[0].is_prefill()
    # assert not short_seqs[1].is_prefill()

    # # both the small seq groups have a new token appended
    # append_new_token(short_seqs[0], 1)
    # append_new_token(short_seqs[1], 1)

    # # in the fifth iteration, large request gets all the budget
    # # while both small requests are decoding
    # seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    # assert seq_group_meta[0].token_chunk_size == 62
    # assert seq_group_meta[1].token_chunk_size == 1  # decode
    # assert seq_group_meta[2].token_chunk_size == 1  # decode
    # assert out.num_prefill_groups == 1
    # assert out.num_batched_tokens == 64


def test_complex():
    block_size = 4
    max_seqs = 60
@@ -508,7 +737,7 @@ def test_chunked_prefill_max_seqs():
    assert not running[1].is_prefill()


def test_perfix_caching():
def test_prefix_caching():
    """Verify allocating full blocks when prefix caching is enabled."""
    block_size = 4
    max_seqs = 10
@@ -548,3 +777,86 @@ def test_perfix_caching():
    assert seq_group_meta[1].token_chunk_size == 12
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 62

def test_prefix_caching_with_concurrent_partial_prefills():
    """Verify allocating full blocks when prefix caching is enabled with
    --max-num-partial-prefills > 1."""
    block_size = 4
    max_seqs = 10
    max_model_len = 8000
    max_num_batched_tokens = 60  # With two slots, each slot will get 30 tokens
    scheduler_config = SchedulerConfig("generate",
                                       max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       max_num_partial_prefills=2)
    cache_config = CacheConfig(block_size,
                               1.0,
                               1,
                               "auto",
                               enable_prefix_caching=True)
    cache_config.num_cpu_blocks = 0
    cache_config.num_gpu_blocks = 32
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: List[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           block_size=block_size,
                                           prompt_length=50)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # To partially prefill both sequences, both can chunk up to 30 tokens
    # But the next lowest multiple of the block size (4) is 28
    assert seq_group_meta[0].token_chunk_size == 28
    assert seq_group_meta[1].token_chunk_size == 28
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 56

    # On the next iteration, both sequences should finish prefill
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # Both sequences have 50 - 28 = 22 tokens left to prefill.
    # This is not a multiple of the block size, but we don't care since we don't
    # cache the final partial block of prefix sequences
    assert seq_group_meta[0].token_chunk_size == 22
    assert seq_group_meta[1].token_chunk_size == 22
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 44

@pytest.mark.parametrize("model", ["facebook/opt-125m"])
|
||||
@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
|
||||
def test_chunked_prefill_with_actual_engine(model: str,
|
||||
max_num_partial_prefills: int):
|
||||
"""Make sure the model can actually sample with concurrent
|
||||
partial prefills
|
||||
"""
|
||||
|
||||
prompt = "hello" * 40
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
max_num_partial_prefills=max_num_partial_prefills,
|
||||
max_num_batched_tokens=40,
|
||||
max_num_seqs=8,
|
||||
enable_chunked_prefill=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
)
|
||||
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
|
||||
for req_num in range(max_num_partial_prefills):
|
||||
engine.add_request(f"{req_num}", prompt, sampling_params)
|
||||
# first step
|
||||
request_outputs = engine.step()
|
||||
# means all are prefilling
|
||||
assert len(request_outputs) == 0
|
||||
assert len(engine.scheduler[0].running) == max_num_partial_prefills
|
||||
|
@@ -1430,6 +1430,17 @@ class SchedulerConfig:
    # Maximum length of a sequence (including prompt and generated text).
    max_model_len: int = 8192

    # Maximum number of sequences that can be partially prefilled concurrently
    max_num_partial_prefills: int = 1

    # Maximum number of "very long prompt" sequences that can be prefilled
    # concurrently (long is defined by long_prefill_threshold)
    max_long_partial_prefills: int = 1

    # calculate context length that determines which sequences are
    # considered "long"
    long_prefill_token_threshold: int = 0

    # The number of slots to allocate per sequence per
    # step, beyond the known token ids. This is used in speculative
    # decoding to store KV activations of tokens which may or may not be
@@ -1537,6 +1548,18 @@ class SchedulerConfig:
                self.max_num_batched_tokens)

        self.chunked_prefill_enabled = self.enable_chunked_prefill
        if self.max_num_partial_prefills > 1:
            if self.long_prefill_token_threshold == 0:
                self.long_prefill_token_threshold = int(self.max_model_len *
                                                        0.04)

            logger.info(
                "Concurrent partial prefills enabled with "
                "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                "long_prefill_token_threshold=%d",
                self.max_num_partial_prefills, self.max_long_partial_prefills,
                self.long_prefill_token_threshold)

        self._verify_args()

    def _verify_args(self) -> None:
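A quick worked example of the default set above (numbers are illustrative, not from the diff): with the class default max_model_len = 8192 and no explicit threshold, long_prefill_token_threshold becomes int(8192 * 0.04) = 327; with max_model_len = 2000, as used by the new scheduler tests, it becomes int(2000 * 0.04) = 80, so only prompts longer than 80 tokens count against the max_long_partial_prefills quota.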
@@ -1568,6 +1591,29 @@ class SchedulerConfig:
                f"({self.num_scheduler_steps}) must be greater than or "
                "equal to 1.")

        if self.max_num_partial_prefills < 1:
            raise ValueError(
                f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
                "must be greater than or equal to 1.")
        elif self.max_num_partial_prefills > 1:
            if not self.chunked_prefill_enabled:
                raise ValueError("Chunked prefill must be enabled to set "
                                 "max_num_partial_prefills > 1.")

        if self.long_prefill_token_threshold > self.max_model_len:
            raise ValueError(
                "long_prefill_token_threshold "
                f"({self.long_prefill_token_threshold}) cannot be greater "
                f"than the max_model_len ({self.max_model_len}).")

        if (self.max_long_partial_prefills
                < 1) or (self.max_long_partial_prefills
                         > self.max_num_partial_prefills):
            raise ValueError(
                f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
                "must be greater than or equal to 1 and less than or equal to "
                f"max_num_partial_prefills ({self.max_num_partial_prefills}).")

    @property
    def is_multi_step(self) -> bool:
        return self.num_scheduler_steps > 1
@@ -17,7 +17,7 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                           SequenceGroupMetadata, SequenceGroupMetadataDelta,
                           SequenceStatus)
                           SequenceStage, SequenceStatus)
from vllm.utils import Device, PyObjectCache

logger = init_logger(__name__)
@@ -39,6 +39,7 @@ class PreemptionMode(enum.Enum):
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    """

    SWAP = enum.auto()
    RECOMPUTE = enum.auto()

@@ -54,6 +55,7 @@ class SchedulingBudget:
    happen if we only have chunked prefill scheduling, we can remove this
    feature from the API when chunked prefill is enabled by default.
    """

    token_budget: int
    max_num_seqs: int
    _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
@@ -132,6 +134,7 @@ class ScheduledSequenceGroup:
@dataclass
class SchedulerOutputs:
    """The scheduling decision made from a scheduler."""

    # Scheduled sequence groups.
    scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup]
    # Number of prefill groups scheduled.
@@ -205,6 +208,7 @@ class SchedulerRunningOutputs:
    Could contain prefill (prefill that's chunked) or decodes. If there's not
    enough memory, it can be preempted (for recompute) or swapped out.
    """

    # Selected sequences that are running and in a decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
    # Selected sequences that are running and in a prefill phase.
@@ -246,6 +250,7 @@ class SchedulerSwappedInOutputs:

    Could contain prefill (prefill that's chunked) or decodes.
    """

    # Selected sequences that are going to be swapped in and is in a
    # decoding phase.
    decode_seq_groups: List[ScheduledSequenceGroup]
@@ -280,6 +285,7 @@ class SchedulerPrefillOutputs:
    Could contain a fresh prefill requests or preempted requests that need
    to be recomputed from scratch.
    """

    # Selected sequences for prefill.
    seq_groups: List[ScheduledSequenceGroup]
    # Ignored sequence groups.
@@ -321,6 +327,100 @@ def scheduled_seq_group_builder():
    # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)


@dataclass
class PartialPrefillMetadata:
    """Holds information about the partial prefills that are currently running
    during a single iteration of the Scheduler.
    When chunked prefill is enabled, we allow a certain number of seqs to be
    partially prefilled during each iteration. Having multiple partial prefills
    in flight allows us to minimize TTFT and avoid decode starvation in cases
    where a single sequence group with a very large prompt blocks the queue for
    too many iterations.
    The number of long prefill requests is limited so that smaller
    requests may jump the queue in front of them and get to the decode
    phase faster.
    """

    # A minimum bound on the total number of prefills to be scheduled during
    # this iteration
    schedulable_prefills: int

    # The number of long prefill requests currently running
    long_prefills: int

    scheduler_config: SchedulerConfig

    def can_schedule(self, seq_group: SequenceGroup) -> bool:
        """When concurrent partial prefills are enabled,
        we limit the number of long requests and only accept
        shorter requests from the queue while running them
        concurrently"""
        return not (seq_group.first_seq.get_num_new_tokens()
                    > self.scheduler_config.long_prefill_token_threshold
                    and self.long_prefills
                    >= self.scheduler_config.max_long_partial_prefills
                    and self.scheduler_config.max_num_partial_prefills > 1)

    def maybe_increment_partial_prefills(self,
                                         seq_group: SequenceGroup) -> None:
        # When a new prefill is scheduled, we need to know if it is a
        # long request
        if (seq_group.first_seq.get_num_new_tokens()
                > self.scheduler_config.long_prefill_token_threshold):
            self.long_prefills += 1

    @classmethod
    def from_queues(
        cls,
        running: Deque[SequenceGroup],
        waiting: Deque[SequenceGroup],
        scheduler_config: SchedulerConfig,
    ) -> "PartialPrefillMetadata":
        """Create a PartialPrefillMetadata object from the current state of
        the scheduler's queues.
        This accounts for the currently running prefill requests, and peeks into
        the waiting queue to see if there are more prefills to potentially be
        scheduled during this iteration."""
        prefills = 0
        long_prefills = 0

        waiting_long_prefills = 0

        for sg in running:
            if sg.first_seq.data.stage == SequenceStage.PREFILL:
                prefills += 1
                if (sg.first_seq.get_num_new_tokens()
                        > scheduler_config.long_prefill_token_threshold):
                    long_prefills += 1

        for sg in waiting:
            # Don't bother looping through the rest of the queue if we know
            # there are already at
            # least max_partial_prefills requests to fill
            if prefills >= scheduler_config.max_num_partial_prefills:
                break

            # Don't count long requests from the waiting queue if we aren't
            # going to schedule them anyway
            if (sg.first_seq.get_num_new_tokens()
                    > scheduler_config.long_prefill_token_threshold):
                if (long_prefills + waiting_long_prefills
                        >= scheduler_config.max_long_partial_prefills):
                    continue
                waiting_long_prefills += 1
            prefills += 1

        # NB: long_prefills and waiting_long_prefills are tracked separately.
        # We don't account for the waiting requests here because we need to use
        # this metadata to track how many have actually been scheduled.
        return PartialPrefillMetadata(
            schedulable_prefills=min(
                prefills, scheduler_config.max_num_partial_prefills),
            long_prefills=long_prefills,
            scheduler_config=scheduler_config,
        )

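To make the admission rule in can_schedule above easier to follow, here is a stand-alone restatement with plain integers in place of the SequenceGroup and SchedulerConfig objects (an illustrative sketch; the function name and parameters are mine, not from the commit):

def can_schedule_prefill(new_tokens: int, long_prefills_running: int,
                         long_prefill_token_threshold: int,
                         max_long_partial_prefills: int,
                         max_num_partial_prefills: int) -> bool:
    # A request is turned away only when all three conditions hold:
    # it is "long", the long-prefill slots are already full, and
    # concurrent partial prefills are actually enabled (> 1).
    is_long = new_tokens > long_prefill_token_threshold
    long_slots_full = long_prefills_running >= max_long_partial_prefills
    concurrent_enabled = max_num_partial_prefills > 1
    return not (is_long and long_slots_full and concurrent_enabled)

# With the settings from test_short_prompts_jump_long_prompts_in_queue
# (threshold 80 = 4% of max_model_len 2000, one long slot, two prefill slots),
# a second 1200-token prompt is deferred while a 40-token prompt is admitted:
assert can_schedule_prefill(1200, 1, 80, 1, 2) is False
assert can_schedule_prefill(40, 1, 80, 1, 2) is True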
class Scheduler:

    def __init__(
@@ -360,7 +460,8 @@ class Scheduler:
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)
            enable_caching=self.cache_config.enable_prefix_caching,
        )

        # Sequence groups in the WAITING state.
        # Contain new prefill or preempted requests.
@@ -421,6 +522,18 @@ class Scheduler:
        # for processing and deallocation by the free_finished_seq_groups()
        self._async_stopped: List[SequenceGroup] = []

        # List with the chunk sizes to hand out to each sequence depending
        # on how many partial prefills are running. This is slightly faster than
        # running an integer division every time a prefill is scheduled.
        # This splits the budget evenly among all prefills.
        self.partial_prefill_budget_lookup_list = [0] * (
            self.scheduler_config.max_num_partial_prefills + 1)
        self.partial_prefill_budget_lookup_list[0] = (
            scheduler_config.max_num_batched_tokens)
        for i in range(1, self.scheduler_config.max_num_partial_prefills + 1):
            self.partial_prefill_budget_lookup_list[i] = (
                scheduler_config.max_num_batched_tokens // i)

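For example, with the values used by the new scheduler tests (max_num_batched_tokens = 64, max_num_partial_prefills = 2) the lookup table above works out as follows; this is an illustrative recomputation, not code from the diff:

max_num_batched_tokens = 64
max_num_partial_prefills = 2
lookup = [max_num_batched_tokens] + [
    max_num_batched_tokens // i for i in range(1, max_num_partial_prefills + 1)
]
assert lookup == [64, 64, 32]  # a lone prefill keeps all 64 tokens, two concurrent prefills get 32 each

which matches the 32-token chunks asserted in test_concurrent_chunking.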
    @property
    def next_cache_id(self):
        return (self.cache_id + 1) % self.num_cache_iters
@@ -500,8 +613,8 @@ class Scheduler:
            self.block_manager.free_cross(seq_group)

    def has_unfinished_seqs(self) -> bool:
        return len(self.waiting) != 0 or len(self.running) != 0 or len(
            self.swapped) != 0
        return (len(self.waiting) != 0 or len(self.running) != 0
                or len(self.swapped) != 0)

    def get_prefix_cache_hit_rate(self, device: Device) -> float:
        return self.block_manager.get_prefix_cache_hit_rate(device)
@@ -523,6 +636,7 @@ class Scheduler:
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> SchedulerRunningOutputs:
        """Schedule sequence groups that are running.

@@ -537,12 +651,14 @@ class Scheduler:
            chunked number of tokens are scheduled if
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.

            partial_prefill_metadata: information about the partial prefills
            that are currently running

        Returns:
            SchedulerRunningOutputs.
        """
        ret: SchedulerRunningOutputs = \
            self._scheduler_running_outputs_cache[self.cache_id].get_object()
        ret: SchedulerRunningOutputs = self._scheduler_running_outputs_cache[
            self.cache_id].get_object()
        ret.blocks_to_swap_out.clear()
        ret.blocks_to_copy.clear()
        ret.decode_seq_groups.clear()
@@ -577,10 +693,14 @@ class Scheduler:
            # 2. If a sequence is running with non-chunked prefill, then
            # there it's a decoding sequence, and the cached tokens info is
            # irrelevant.
            num_uncached_new_tokens, _ = (
            num_uncached_new_tokens, _ = \
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group, SequenceStatus.RUNNING, enable_chunking,
                    budget))
                    seq_group,
                    SequenceStatus.RUNNING,
                    enable_chunking,
                    budget,
                    partial_prefill_metadata,
                )

            num_running_tokens = num_uncached_new_tokens
            if num_running_tokens == 0:
@@ -593,8 +713,8 @@ class Scheduler:
            # to process the final tokens. The check below avoids this extra
            # decode run when the model max len is reached, in order to avoid
            # a memory overflow.
            if self.use_async_output_proc and seq_group.seqs[0].get_len(
            ) > self.scheduler_config.max_model_len:
            if (self.use_async_output_proc and seq_group.seqs[0].get_len()
                    > self.scheduler_config.max_model_len):
                self._async_stopped.append(seq_group)
                continue

@@ -653,8 +773,9 @@ class Scheduler:
                self._append_slots(seq_group, blocks_to_copy, enable_chunking)
                is_prefill = seq_group.is_prefill()

                scheduled_seq_group: ScheduledSequenceGroup = \
                    self._scheduled_seq_group_cache[self.cache_id].get_object()
                scheduled_seq_group: ScheduledSequenceGroup = (
                    self._scheduled_seq_group_cache[
                        self.cache_id].get_object())
                scheduled_seq_group.seq_group = seq_group
                if is_prefill:
                    scheduled_seq_group.token_chunk_size = num_running_tokens
@@ -731,7 +852,8 @@ class Scheduler:
                logger.warning(
                    "Failing the request %s because there's not enough kv "
                    "cache blocks to run the entire sequence.",
                    seq_group.request_id)
                    seq_group.request_id,
                )
                for seq in seq_group.get_seqs():
                    seq.status = SequenceStatus.FINISHED_IGNORED
                infeasible_seq_groups.append(seq_group)
@@ -800,16 +922,17 @@ class Scheduler:
        )

    def _get_prompt_limit(self, seq_group: SequenceGroup) -> int:
        if self.scheduler_config.chunked_prefill_enabled and \
                not self.scheduler_config.is_multi_step:
        if (self.scheduler_config.chunked_prefill_enabled
                and not self.scheduler_config.is_multi_step):
            prompt_limit = self.scheduler_config.max_model_len
        else:
            prompt_limit = min(self.scheduler_config.max_model_len,
                               self.scheduler_config.max_num_batched_tokens)
            prompt_limit = min(
                self.scheduler_config.max_model_len,
                self.scheduler_config.max_num_batched_tokens,
            )

        # Model is fine tuned with long context. Return the fine tuned max_len.
        if (seq_group.lora_request
                and seq_group.lora_request.long_lora_max_len):
        if seq_group.lora_request and seq_group.lora_request.long_lora_max_len:
            assert prompt_limit <= seq_group.lora_request.long_lora_max_len
            return seq_group.lora_request.long_lora_max_len
        else:
@@ -817,7 +940,7 @@ class Scheduler:

    def _get_priority(self,
                      seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
        """ Get the priority of the sequence group.
        """Get the priority of the sequence group.
        Highest preference to user-defined priority, followed by arrival time.
        Args:
            seq_group: The sequence group input.
@@ -850,14 +973,14 @@ class Scheduler:
        if waiting_queue:
            seq_group = waiting_queue.popleft()
            num_new_seqs = seq_group.get_max_num_running_seqs()
            num_new_tokens_uncached, _ = (
            num_new_tokens_uncached, _ = \
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group, SequenceStatus.WAITING, False, budget))
                    seq_group, SequenceStatus.WAITING, False, budget)

            #Only preempt if priority inversion exists
            # Only preempt if priority inversion exists
            while running_queue and self._get_priority(
                    running_queue[-1]) > self._get_priority(seq_group):
                #Only preempt if waiting sequence cannot be allocated
                # Only preempt if waiting sequence cannot be allocated
                can_allocate = self.block_manager.can_allocate(seq_group)
                if (num_new_tokens_uncached > 0
                        and can_allocate == AllocStatus.OK
@@ -867,7 +990,7 @@ class Scheduler:
                        )):
                    break

                #Adjust budget to remove the victim sequence group
                # Adjust budget to remove the victim sequence group
                vseq_group = running_queue.pop()
                num_running_tokens_uncached, _ = (
                    self._get_num_new_uncached_and_cached_tokens(
@@ -878,11 +1001,11 @@ class Scheduler:
                budget.subtract_num_seqs(vseq_group.request_id,
                                         num_running_seqs)

                #Preempt out the victim sequence group
                # Preempt out the victim sequence group
                self._preempt(vseq_group, blocks_to_swap_out)
                waiting_queue.appendleft(vseq_group)
                force_preemption_count += 1
            #Put the sequence back into the waiting queue
            # Put the sequence back into the waiting queue
            waiting_queue.appendleft(seq_group)

        waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
@@ -896,6 +1019,7 @@ class Scheduler:
        budget: SchedulingBudget,
        curr_loras: Optional[Set[int]],
        enable_chunking: bool = False,
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> SchedulerPrefillOutputs:
        """Schedule sequence groups that are in prefill stage.

@@ -916,10 +1040,20 @@ class Scheduler:
            chunked number of tokens are scheduled if
            `budget.num_batched_tokens` has not enough capacity to schedule
            all tokens.
            partial_prefill_metadata: information about the partial prefills
            that are currently running

        Returns:
            SchedulerPrefillOutputs.
        """
        if budget.remaining_token_budget() == 0:
            # Do nothing: Can't add any more prefill anyway
            return SchedulerPrefillOutputs(
                seq_groups=[],
                ignored_seq_groups=[],
                num_lookahead_slots=self._get_num_lookahead_slots(
                    is_prefill=True, enable_chunking=enable_chunking),
            )
        ignored_seq_groups: List[SequenceGroup] = []
        seq_groups: List[ScheduledSequenceGroup] = []

@@ -933,10 +1067,19 @@ class Scheduler:
            assert len(waiting_seqs) == 1, (
                "Waiting sequence group should have only one prompt "
                "sequence.")
            if (partial_prefill_metadata is not None
                    and not partial_prefill_metadata.can_schedule(seq_group)):
                leftover_waiting_sequences.appendleft(seq_group)
                waiting_queue.popleft()
                continue
            num_new_tokens_uncached, num_new_tokens_cached = (
                self._get_num_new_uncached_and_cached_tokens(
                    seq_group, SequenceStatus.WAITING, enable_chunking,
                    budget))
                    seq_group,
                    SequenceStatus.WAITING,
                    enable_chunking,
                    budget,
                    partial_prefill_metadata=partial_prefill_metadata,
                ))
            num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached

            if not enable_chunking:
@@ -947,7 +1090,10 @@ class Scheduler:
            if num_new_tokens > prompt_limit:
                logger.warning(
                    "Input prompt (%d tokens) is too long"
                    " and exceeds limit of %d", num_new_tokens, prompt_limit)
                    " and exceeds limit of %d",
                    num_new_tokens,
                    prompt_limit,
                )
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
@@ -968,7 +1114,9 @@ class Scheduler:
                logger.warning(
                    "Input prompt (%d tokens) + lookahead slots (%d) is "
                    "too long and exceeds the capacity of block_manager",
                    num_new_tokens, num_lookahead_slots)
                    num_new_tokens,
                    num_lookahead_slots,
                )
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
@@ -1009,6 +1157,10 @@ class Scheduler:
            waiting_queue.popleft()
            self._allocate_and_set_running(seq_group)

            if partial_prefill_metadata is not None:
                partial_prefill_metadata.maybe_increment_partial_prefills(
                    seq_group)

            if enable_chunking and self.scheduler_config.is_multi_step:
                blocks_to_copy: List[Tuple[int, int]] = []
                # init_multi_step_from_lookahead_slots happens in append_slots
@@ -1024,7 +1176,8 @@ class Scheduler:
                    num_scheduler_steps=self.scheduler_config.
                    num_scheduler_steps,
                    is_multi_step=self.scheduler_config.is_multi_step,
                    enable_chunking=enable_chunking)
                    enable_chunking=enable_chunking,
                )

            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
@@ -1045,11 +1198,12 @@ class Scheduler:
            seq_groups=seq_groups,
            ignored_seq_groups=ignored_seq_groups,
            num_lookahead_slots=self._get_num_lookahead_slots(
                is_prefill=True, enable_chunking=enable_chunking))
                is_prefill=True, enable_chunking=enable_chunking),
        )

    def _schedule_default(self) -> SchedulerOutputs:
        """Schedule queued requests.


        The current policy is designed to optimize the throughput. First,
        it batches as many prefill requests as possible. And it schedules
        decodes. If there's a pressure on GPU memory, decode requests can
@@ -1065,9 +1219,9 @@ class Scheduler:
        for seq_group in self.running:
            budget.add_num_seqs(seq_group.request_id,
                                seq_group.get_max_num_running_seqs())
        curr_loras = set(
        curr_loras = (set(
            seq_group.lora_int_id for seq_group in self.running
            if seq_group.lora_int_id > 0) if self.lora_enabled else None
            if seq_group.lora_int_id > 0) if self.lora_enabled else None)

        prefills = SchedulerPrefillOutputs.create_empty()
        running_scheduled = SchedulerRunningOutputs.create_empty()
@@ -1093,9 +1247,10 @@ class Scheduler:

        # If any sequence group is preempted, do not swap in any sequence
        # group. because it means there's no slot for new running requests.
        if len(running_scheduled.preempted) + len(
                running_scheduled.swapped_out) == 0:
            swapped_in = self._schedule_swapped(budget, curr_loras)
        if (len(running_scheduled.preempted) +
                len(running_scheduled.swapped_out) == 0):
            swapped_in = \
                self._schedule_swapped(budget, curr_loras)

        assert (budget.num_batched_tokens
                <= self.scheduler_config.max_num_batched_tokens)
@@ -1115,8 +1270,8 @@ class Scheduler:

        # Update swapped requests.
        self.swapped.extend(running_scheduled.swapped_out)
        preempted = (len(running_scheduled.preempted) +
                     len(running_scheduled.swapped_out))
        preempted = len(running_scheduled.preempted) + len(
            running_scheduled.swapped_out)

        # There should be no prefill from running queue because this policy
        # doesn't allow chunked prefills.
@@ -1154,7 +1309,7 @@ class Scheduler:

    def _schedule_chunked_prefill(self) -> SchedulerOutputs:
        """Schedule queued requests.


        Chunked prefill allows to chunk prefill requests, batch them together
        with decode requests. This policy 1. schedule as many decoding requests
        as possible. 2. schedule chunked prefill requests that are not
@@ -1175,10 +1330,20 @@ class Scheduler:
        prefills = SchedulerPrefillOutputs.create_empty()
        swapped_in = SchedulerSwappedInOutputs.create_empty()

        # Create partial prefill metadata
        partial_prefill_metadata = PartialPrefillMetadata.from_queues(
            running=self.running,
            waiting=self.waiting,
            scheduler_config=self.scheduler_config,
        )

        # Decoding should be always scheduled first by fcfs.
        running_scheduled = self._schedule_running(budget,
                                                   curr_loras,
                                                   enable_chunking=True)
        running_scheduled = self._schedule_running(
            budget,
            curr_loras,
            enable_chunking=True,
            partial_prefill_metadata=partial_prefill_metadata,
        )

        # Schedule swapped out requests.
        # If preemption happens, it means we don't have space for swap-in.
@@ -1186,9 +1351,12 @@ class Scheduler:
                running_scheduled.swapped_out) == 0:
            swapped_in = self._schedule_swapped(budget, curr_loras)

        prefills = self._schedule_prefills(budget,
                                           curr_loras,
                                           enable_chunking=True)
        prefills = self._schedule_prefills(
            budget,
            curr_loras,
            enable_chunking=True,
            partial_prefill_metadata=partial_prefill_metadata,
        )

        assert (budget.num_batched_tokens
                <= self.scheduler_config.max_num_batched_tokens)
@@ -1207,8 +1375,15 @@ class Scheduler:
            [s.seq_group for s in swapped_in.prefill_seq_groups])
        self.running.extend(
            [s.seq_group for s in running_scheduled.decode_seq_groups])
        # Because multiple prefills may be running concurrently, we need to
        # make sure that prefills which are scheduled to finish are listed
        # before those that won't. This is so that on the next scheduling
        # iteration when they have transitioned to the decode stage, they are
        # properly prioritized over sequences that are still in the prefill
        # stage.
        self.running.extend(
            [s.seq_group for s in running_scheduled.prefill_seq_groups])
            self._order_finishing_prefills_first(
                running_scheduled.prefill_seq_groups))
        self.running.extend([s.seq_group for s in prefills.seq_groups])

        # Update swapped requests.
@@ -1225,7 +1400,7 @@ class Scheduler:
        # If all prompts, then we set num_lookahead_slots to 0
        # this allows us to go through the `no_spec` path in
        # `spec_decode_worker.py`
        all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
        all_prefills = len(scheduled_seq_groups) == num_prefill_groups
        num_lookahead_slots = (0 if
                               (all_prefills
                                and not self.scheduler_config.is_multi_step)
@@ -1247,6 +1422,21 @@ class Scheduler:
                len(running_scheduled.swapped_out)),
        )

    def _order_finishing_prefills_first(
            self, scheduled_prefill_seqs: List[ScheduledSequenceGroup]
    ) -> List[SequenceGroup]:
        """Returns a list of prefilling SequenceGroups where sequences that are
        scheduled to finish prefilling are listed first"""
        finishing = [
            s.seq_group for s in scheduled_prefill_seqs
            if s.seq_group.get_num_uncomputed_tokens() == s.token_chunk_size
        ]
        not_finishing = [
            s.seq_group for s in scheduled_prefill_seqs
            if s.seq_group.get_num_uncomputed_tokens() != s.token_chunk_size
        ]
        return finishing + not_finishing

    def _schedule(self) -> SchedulerOutputs:
        """Schedule queued requests."""
        if self.scheduler_config.chunked_prefill_enabled:
@@ -1385,10 +1575,12 @@ class Scheduler:
                # between engine and worker.
                # the subsequent comms can still use delta, but
                # `multi_modal_data` will be None.
                multi_modal_data=seq_group.multi_modal_data
                if scheduler_outputs.num_prefill_groups > 0 else None,
                multi_modal_placeholders=seq_group.multi_modal_placeholders
                if scheduler_outputs.num_prefill_groups > 0 else None,
                multi_modal_data=(seq_group.multi_modal_data
                                  if scheduler_outputs.num_prefill_groups
                                  > 0 else None),
                multi_modal_placeholders=(
                    seq_group.multi_modal_placeholders
                    if scheduler_outputs.num_prefill_groups > 0 else None),
                mm_processor_kwargs=seq_group.mm_processor_kwargs,
                prompt_adapter_request=seq_group.prompt_adapter_request,
            )
@@ -1494,10 +1686,12 @@ class Scheduler:
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            seq.status = SequenceStatus.RUNNING

    def _append_slots(self,
                      seq_group: SequenceGroup,
                      blocks_to_copy: List[Tuple[int, int]],
                      enable_chunking: bool = False) -> None:
    def _append_slots(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: List[Tuple[int, int]],
        enable_chunking: bool = False,
    ) -> None:
        """Appends new slots to the sequences in the given sequence group.

        Args:
@@ -1518,7 +1712,8 @@ class Scheduler:
            num_lookahead_slots,
            num_scheduler_steps=self.scheduler_config.num_scheduler_steps,
            is_multi_step=self.scheduler_config.is_multi_step,
            enable_chunking=enable_chunking)
            enable_chunking=enable_chunking,
        )

        seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING
        if self.scheduler_config.is_multi_step and enable_chunking:
@@ -1561,8 +1756,11 @@ class Scheduler:
            "not enough KV cache space. This can affect the end-to-end "
            "performance. Increase gpu_memory_utilization or "
            "tensor_parallel_size to provide more KV cache memory. "
            "total_num_cumulative_preemption=%d", seq_group.request_id,
            preemption_mode, self.num_cumulative_preemption + 1)
            "total_num_cumulative_preemption=%d",
            seq_group.request_id,
            preemption_mode,
            self.num_cumulative_preemption + 1,
        )
        self.num_cumulative_preemption += 1

        if preemption_mode == PreemptionMode.RECOMPUTE:
@@ -1668,6 +1866,7 @@ class Scheduler:
        status: SequenceStatus,
        enable_chunking: bool,
        budget: SchedulingBudget,
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> Tuple[int, int]:
        """
        Returns the number of new uncached and cached tokens to schedule for a
@@ -1691,6 +1890,8 @@ class Scheduler:
                to schedule.
            enable_chunking: Whether to chunk the number of tokens to compute.
            budget: The budget to chunk the number of tokens to compute.
            partial_prefill_metadata: information about the partial prefills
            that are currently running


        Returns:
@@ -1768,6 +1969,8 @@ class Scheduler:
                budget,
                self._get_prompt_limit(seq_group),
                num_uncached_new_tokens,
                self.partial_prefill_budget_lookup_list,
                partial_prefill_metadata,
            )

        return num_uncached_new_tokens, num_cached_new_tokens
@@ -1779,6 +1982,8 @@ class Scheduler:
        budget: SchedulingBudget,
        prompt_limit: int,
        num_new_tokens: int,
        partial_prefill_budget_lookup_list: List[int],
        partial_prefill_metadata: Optional[PartialPrefillMetadata] = None,
    ) -> int:
        """
        Chunks the number of new tokens to schedule based on the budget when
@@ -1811,29 +2016,31 @@ class Scheduler:
            # the sequence.
            return num_new_tokens

            return (0 if num_new_tokens > remaining_token_budget else
                    num_new_tokens)
            return 0 if num_new_tokens > \
                remaining_token_budget else num_new_tokens

        # Get the number of tokens to allocate to this prefill slot
        prefill_slot_budget = (
            remaining_token_budget if partial_prefill_metadata is None else
            partial_prefill_budget_lookup_list[
                partial_prefill_metadata.schedulable_prefills])

        if cache_config.enable_prefix_caching:
            # Adjust the remaining token budget to be divisible by the block
            # size when prefix caching is enabled.

            # When prefix caching is enabled, we always allocate
            # the number of new tokens that is dividable by the block
            # size to avoid partial block matching.
            # When prefix caching is enabled and we're partially prefilling
            # a sequence, we always allocate a number of new tokens that is
            # divisible by the block size to avoid partial block matching.
            block_size = cache_config.block_size
            remainder = budget.token_budget % block_size
            if remainder != 0:
                raise ValueError("When enabling chunked prefill and "
                                 "prefix caching, max_num_batched_tokens "
                                 "(chunk size) must be dividable by "
                                 "block size, but got chunk_size "
                                 f"({budget.token_budget}) % block_size "
                                 f"({block_size}) = {remainder}")
            # Round down to block size.
            remaining_token_budget = (remaining_token_budget // block_size *
                                      block_size)
            # Don't exceed either the total budget or slot budget.
            # Take min of those and get the next lowest multiple of the
            # block size:
            remaining_token_budget = (
                min(remaining_token_budget, prefill_slot_budget) //
                block_size) * block_size
            # NB: In the case where num_new_tokens < budget, we are
            # finishing prefill for this sequence, so we do not need to
            # allocate a full block.

        num_new_tokens = min(num_new_tokens, remaining_token_budget)
        num_new_tokens = min(num_new_tokens, remaining_token_budget,
                             prefill_slot_budget)

        return num_new_tokens

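To tie the budget arithmetic above back to the tests, here is the calculation for test_prefix_caching_with_concurrent_partial_prefills as an illustrative recomputation with plain integers (not code from the diff): a 60-token budget shared by 2 prefill slots gives each slot 30 tokens, which rounds down to 28 with block_size = 4, so a 50-token prompt prefills as 28 + 22 tokens.

block_size = 4
token_budget = 60        # max_num_batched_tokens in the test
schedulable_prefills = 2
remaining_token_budget = token_budget
prefill_slot_budget = token_budget // schedulable_prefills            # 30
# round the smaller of the two budgets down to a multiple of block_size
remaining_token_budget = (min(remaining_token_budget, prefill_slot_budget)
                          // block_size) * block_size                 # 28
num_new_tokens = min(50, remaining_token_budget, prefill_slot_budget)
assert (remaining_token_budget, num_new_tokens) == (28, 28)           # 22 tokens remain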
@@ -120,6 +120,9 @@ class EngineArgs:
    cpu_offload_gb: float = 0  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_partial_prefills: Optional[int] = 1
    max_long_partial_prefills: Optional[int] = 1
    long_prefill_token_threshold: Optional[int] = 0
    max_num_seqs: Optional[int] = None
    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
    disable_log_stats: bool = False
@@ -515,6 +518,31 @@
                            default=EngineArgs.max_num_batched_tokens,
                            help='Maximum number of batched tokens per '
                            'iteration.')
        parser.add_argument(
            "--max-num-partial-prefills",
            type=int,
            default=EngineArgs.max_num_partial_prefills,
            help="For chunked prefill, the max number of concurrent \
            partial prefills."
            "Defaults to 1",
        )
        parser.add_argument(
            "--max-long-partial-prefills",
            type=int,
            default=EngineArgs.max_long_partial_prefills,
            help="For chunked prefill, the maximum number of prompts longer "
            "than --long-prefill-token-threshold that will be prefilled "
            "concurrently. Setting this less than --max-num-partial-prefills "
            "will allow shorter prompts to jump the queue in front of longer "
            "prompts in some cases, improving latency. Defaults to 1.")
        parser.add_argument(
            "--long-prefill-token-threshold",
            type=float,
            default=EngineArgs.long_prefill_token_threshold,
            help="For chunked prefill, a request is considered long if the "
            "prompt is longer than this number of tokens. Defaults to 4%% of "
            "the model's context length.",
        )
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
@@ -1244,7 +1272,11 @@ class EngineArgs:
            multi_step_stream_outputs=self.multi_step_stream_outputs,
            send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                             and parallel_config.use_ray),
            policy=self.scheduling_policy)
            policy=self.scheduling_policy,
            max_num_partial_prefills=self.max_num_partial_prefills,
            max_long_partial_prefills=self.max_long_partial_prefills,
            long_prefill_token_threshold=self.long_prefill_token_threshold,
        )
        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
@@ -958,7 +958,9 @@ def get_logprobs(
    if len(query_indices) == 0:
        empty_sampled_logprob: SampleLogprobs = []
        empty_prompt_logprob: Optional[PromptLogprobs] = None
        return [empty_prompt_logprob], [empty_sampled_logprob]
        num_seq_groups = len(sampling_metadata.seq_groups)
        return [empty_prompt_logprob
                ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups

    selected_logprobs, ranks = None, None
    top_logprobs, top_token_ids = None, None
@@ -1225,6 +1227,10 @@ def _build_sampler_output(
        assert sample_logprobs is not None
        assert not isinstance(maybe_deferred_sample_results,
                              SampleResultArgsType)
        assert len(sampling_metadata.seq_groups) \
            == len(maybe_deferred_sample_results) \
            == len(prompt_logprobs) \
            == len(sample_logprobs)
        deferred_sample_results_args = None

    for (seq_group, sample_result, group_prompt_logprobs,