Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)

[Speculative decoding 4/9] Lookahead scheduling for speculative decoding (#3250)
@@ -77,6 +77,159 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Use a large block size to trigger more copy-on-writes.
        "block_size": 32,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
                                        test_llm_generator, batch_size):
    """Verify beam search equality with block manager v1 and v2.

    This requires copy-on-writes; if the v1 and v2 output is the same, then
    we have some confidence cow is working.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        use_beam_search=True,
        best_of=2,
    )

    print('Getting token ids from block manager v1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids from block manager v2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # Our prompts will generate 128 tokens; since the prompts themselves are
        # small, we don't need much KV space beyond 128.
        "max_model_len": 160,

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Lookahead scheduling only supported in v2 block manager.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            "block_size": 16,

            # Allow only 2 sequences of ~128 tokens in worst case.
            # Note 8 = 128/block_size
            "forced_num_gpu_blocks": 2 * (8 + 1),
        },
        {
            "block_size": 8,

            # Allow only 2 sequences of ~128 tokens in worst case.
            # Note 16 = 128/block_size
            "forced_num_gpu_blocks": 2 * (16 + 1),
        }
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "num_lookahead_slots": 0,
}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        # We run one test with block_size < lookahead_slots, one test with
        # block_size > lookahead_slots
        "num_lookahead_slots": 10,
    }])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
                                                   test_llm_generator,
                                                   batch_size):
    """Verify vLLM produces the same output with greedy sampling, when lookahead
    scheduling is used vs. not.

    Lookahead scheduling is not expected to modify the output, as it simply
    allocates empty slots ahead of the known token ids in a sliding fashion.

    This test constrains the total number of blocks to force preemption. It also
    varies the block size so that the lookahead size is less than and greater
    than the block size.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids without lookahead scheduling')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids with lookahead scheduling')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
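The forced_num_gpu_blocks values above follow from simple block arithmetic: each sequence emits ~128 output tokens and so needs ceil(128 / block_size) blocks plus roughly one more for its short prompt; capping the pool at two sequences' worth forces preemption once more sequences are running. A minimal sketch of that arithmetic (blocks_needed is an illustrative helper, not a vLLM API):

def blocks_needed(num_tokens: int, block_size: int) -> int:
    # Ceiling division: how many fixed-size KV blocks num_tokens occupy.
    return -(-num_tokens // block_size)

output_len = 128
# block_size=16: 8 output blocks + 1 prompt block per sequence, 2 sequences.
assert 2 * (blocks_needed(output_len, 16) + 1) == 2 * (8 + 1)
# block_size=8: 16 output blocks + 1 prompt block per sequence, 2 sequences.
assert 2 * (blocks_needed(output_len, 8) + 1) == 2 * (16 + 1)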
tests/core/block/test_block_manager_v2.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import pytest

from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, SequenceStatus
from vllm.utils import chunk_list

from ..utils import create_seq_group


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
                                num_gpu_blocks: int, watermark: float):
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        watermark=watermark,
    )
    num_watermark_blocks = int(watermark * num_gpu_blocks)

    num_output_blocks_per_seq = 1

    # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
    # the current implementation assumes all seqs are new prompts / don't have
    # different output lens.
    num_output_blocks = num_output_blocks_per_seq

    for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
        seq_group = create_seq_group(
            seq_prompt_len=block_size * num_prompt_blocks,
            seq_output_lens=[
                block_size * num_output_blocks_per_seq
                for _ in range(num_seqs_per_group)
            ],
        )

        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks

        can_allocate_result = block_manager.can_allocate(seq_group)

        num_required_blocks = num_prompt_blocks + num_output_blocks

        if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
            assert can_allocate_result == AllocStatus.NEVER
        elif num_gpu_blocks >= num_required_blocks:
            assert can_allocate_result == AllocStatus.OK
        else:
            assert can_allocate_result == AllocStatus.LATER


@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("prompt_len", [1, 7, 8])
@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129])
@pytest.mark.parametrize("num_lookahead_slots", [0, 10])
def test_append_slots(block_size, prompt_len, num_slots_to_append,
                      num_lookahead_slots):
    """Verify append_slots consumes the correct number of blocks from the block
    table.
    """

    num_gpu_blocks = 1024
    watermark = 0.1
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        watermark=watermark,
    )

    seq_group = create_seq_group(
        seq_prompt_len=prompt_len,
        seq_output_lens=[0],
    )

    # Allocate seq
    assert block_manager.can_allocate(seq_group)
    block_manager.allocate(seq_group)

    # Set seq to RUNNING
    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    # Append tokens to the sequence
    for token_id in range(num_slots_to_append):
        seq.append_token_id(token_id, {token_id: Logprob(0.0)})

    # Append slots for new tokens and lookahead slots.
    free_blocks_before_append = block_manager.get_num_free_gpu_blocks()
    block_manager.append_slots(seq, num_lookahead_slots)
    num_consumed_blocks = (free_blocks_before_append -
                           block_manager.get_num_free_gpu_blocks())

    # Expect consumed blocks to be new blocks required to support the new slots.
    expected_consumed_blocks = len(
        chunk_list(
            list(
                range(prompt_len + num_slots_to_append + num_lookahead_slots)),
            block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
    assert num_consumed_blocks == expected_consumed_blocks
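The expected_consumed_blocks expression above is a difference of ceiling divisions: blocks spanned by all slots after the append, minus blocks the prompt already spans. A standalone sketch of the same calculation (illustrative only; it assumes chunk_list splits a list into block_size-sized pieces, as in vllm.utils):

import math

def expected_consumed_blocks(prompt_len: int, num_appended_slots: int,
                             block_size: int) -> int:
    # Blocks spanned after the append, minus blocks the prompt already spans.
    total = prompt_len + num_appended_slots
    return math.ceil(total / block_size) - math.ceil(prompt_len / block_size)

# prompt_len=7, 8 new tokens + 10 lookahead slots, block_size=8:
# ceil(25 / 8) - ceil(7 / 8) = 4 - 1 = 3 newly consumed blocks.
assert expected_consumed_blocks(7, 8 + 10, 8) == 3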
@@ -1,50 +0,0 @@
import pytest

from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.interfaces import AllocStatus

from ..utils import create_seq_group


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
                                num_gpu_blocks: int, watermark: float):
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        watermark=watermark,
    )
    num_watermark_blocks = int(watermark * num_gpu_blocks)

    num_output_blocks_per_seq = 1

    # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
    # the current implementation assumes all seqs are new prompts / don't have
    # different output lens.
    num_output_blocks = num_output_blocks_per_seq

    for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
        seq_group = create_seq_group(
            seq_prompt_lens=block_size * num_prompt_blocks,
            seq_output_lens=[
                block_size * num_output_blocks_per_seq
                for _ in range(num_seqs_per_group)
            ],
        )

        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks

        can_allocate_result = block_manager.can_allocate(seq_group)

        num_required_blocks = num_prompt_blocks + num_output_blocks

        if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
            assert can_allocate_result == AllocStatus.NEVER
        elif num_gpu_blocks >= num_required_blocks:
            assert can_allocate_result == AllocStatus.OK
        else:
            assert can_allocate_result == AllocStatus.LATER
@@ -498,3 +498,78 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int,

    # After free, expect all blocks to be freed.
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks


@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("num_new_tokens", [1, 16, 129])
@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int,
                                            num_new_tokens: int,
                                            num_lookahead_slots: int,
                                            allocator_type: str):
    """Verify correct calculation of get_num_blocks_touched_by_append_slots.

    This is done by using copy-on-write, which requires any modified block to
    be copied before write if the refcount > 1. We set the refcount>1 by forking
    a sequence, then measure the free blocks before and after an append. If the
    number of consumed blocks equals what `get_num_blocks_touched_by_append_
    slots` returns, then the calculation is correct.
    """

    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    token_ids_to_append = list(range(num_new_tokens))

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    block_table.allocate(token_ids=token_ids, device=Device.GPU)

    # Add lookahead before fork so both sequences have the same lookahead
    # blocks.
    block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots)

    # Fork sequence so that every block has refcount > 1.
    _ = block_table.fork()

    # Determine how many blocks should be touched.
    expected_num_touched_blocks = (
        block_table.get_num_blocks_touched_by_append_slots(
            token_ids=token_ids_to_append,
            num_lookahead_slots=num_lookahead_slots))

    # Measure how many blocks are touched by measuring num_free_blocks before
    # and after the append.
    #
    # We expect append_token_ids to CoW all mutated blocks that have refcount>1.
    num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU)
    block_table.append_token_ids(token_ids_to_append, num_lookahead_slots)
    num_consumed_blocks = (num_free_blocks_before_append -
                           allocator.get_num_free_blocks(Device.GPU))

    # TODO(cade) ensure equality when num_lookahead_slots > 0.
    # The reason we have < is because lookahead blocks are not copied eagerly;
    # they are copied on first write. This will cause issues for beam search +
    # speculative decoding. This is acceptable for now as it is a large effort
    # to combine the two. To fix this, we can ensure single sequence ownership
    # of lookahead blocks by appending empty slots to each block, which will
    # trigger the CoW.
    #
    # Until then, we can accept that the consumed tokens are <= the expected
    # tokens when appending with lookahead.
    if num_lookahead_slots > 0:
        assert num_consumed_blocks <= expected_num_touched_blocks
    else:
        assert num_consumed_blocks == expected_num_touched_blocks
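The measurement trick in this test rests on the copy-on-write rule itself: a block shared by more than one sequence (refcount > 1) cannot be written in place, so every appended slot costs a fresh allocation that shows up in the free-block count. A toy statement of that rule (not the vLLM allocator, just the invariant the test relies on):

# After a fork, a block is shared by two sequences (refcount == 2). Writing to
# a shared block must allocate a copy; writing to an exclusively owned block
# (refcount == 1) can mutate it in place.
def needs_copy_on_write(refcount: int) -> bool:
    return refcount > 1

assert needs_copy_on_write(2) is True   # shared after fork: append allocates a copy
assert needs_copy_on_write(1) is False  # exclusive: append writes in place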
@@ -103,9 +103,9 @@ def test_append_slot_single_seq():
    block_manager.allocate(seq_group)

    # Nothing to append. Sequence has no new logical blocks.
-   assert block_manager.can_append_slot(seq_group)
+   assert block_manager.can_append_slots(seq_group)
    before_blocks = block_manager.get_num_free_gpu_blocks()
-   assert not block_manager.append_slot(prompt)
+   assert not block_manager.append_slots(prompt)
    after_blocks = block_manager.get_num_free_gpu_blocks()
    assert before_blocks == after_blocks

@@ -114,9 +114,9 @@ def test_append_slot_single_seq():
        token_id = i + 5
        prompt.append_token_id(token_id, {token_id: Logprob(0.0)})

-   assert block_manager.can_append_slot(seq_group)
+   assert block_manager.can_append_slots(seq_group)
    before_blocks = block_manager.get_num_free_gpu_blocks()
-   assert not block_manager.append_slot(prompt)
+   assert not block_manager.append_slots(prompt)
    after_blocks = block_manager.get_num_free_gpu_blocks()
    assert before_blocks - after_blocks == 1

@@ -150,13 +150,13 @@ def test_append_slot_cow():
    child.append_token_id(token_id, {token_id: Logprob(0.0)})
    block_manager.fork(prompt, child)

-   assert block_manager.can_append_slot(seq_group)
+   assert block_manager.can_append_slots(seq_group)
    before_blocks = block_manager.get_num_free_gpu_blocks()

-   maybe_src_dst_block = block_manager.append_slot(child)
-   assert maybe_src_dst_block is not None
-   src_block, dst_block = maybe_src_dst_block
-   assert src_block != dst_block
+   cows = block_manager.append_slots(child)
+   assert cows
+   for src_block, dst_blocks in cows.items():
+       assert src_block not in dst_blocks

    after_blocks = block_manager.get_num_free_gpu_blocks()
    assert before_blocks - after_blocks == 1

@@ -184,7 +184,7 @@ def test_fork():
    token_id = 4
    # Append token to child. Block is shared so copy on write occurs.
    child.append_token_id(token_id, {token_id: Logprob(0.0)})
-   block_manager.append_slot(child)
+   block_manager.append_slots(child)
    assert block_manager.get_block_table(
        prompt) != block_manager.get_block_table(child)

@@ -325,7 +325,7 @@ def test_sliding_window_multi_seq():
    token_id = 4
    # Append token to child. Block is shared so copy on write occurs.
    child.append_token_id(token_id, {token_id: Logprob(0.0)})
-   block_manager.append_slot(child)
+   block_manager.append_slots(child)

    # assert the number of blocks allocated is correct
    # we will use now one block more. Each seq will use 2 blocks,

@@ -335,7 +335,7 @@ def test_sliding_window_multi_seq():

    token_id = 5
    parent.append_token_id(token_id, {token_id: Logprob(0.0)})
-   block_manager.append_slot(parent)
+   block_manager.append_slots(parent)

    # assert the number of blocks allocated is correct
    # no change, because both sequences are still just sharing one block
@@ -24,7 +24,7 @@ def create_dummy_prompt(


def create_seq_group(
-       seq_prompt_lens=1024,
+       seq_prompt_len=1024,
        seq_output_lens=(128, ),
        request_id='0',
        seq_id_start=0,

@@ -32,7 +32,7 @@ def create_seq_group(

    assert len(seq_output_lens) > 0

-   prompt_token_ids = [0] * seq_prompt_lens
+   prompt_token_ids = [0] * seq_prompt_len

    seqs = []
    for seq_id_offset, output_len in enumerate(seq_output_lens):
@@ -530,9 +530,13 @@ class SchedulerConfig:
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
+       use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
+       num_lookahead_slots: The number of slots to allocate per sequence per
+           step, beyond the known token ids. This is used in speculative
+           decoding to store KV activations of tokens which may or may not be
+           accepted.
        delay_factor: Apply a delay (of delay factor multiplied by previous
            prompt latency) before scheduling next prompt.
-       use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
        enable_chunked_prefill: If True, prefill requests can be chunked based
            on the remaining max_num_batched_tokens.
    """

@@ -543,6 +547,7 @@ class SchedulerConfig:
        max_num_seqs: int,
        max_model_len: int,
        use_v2_block_manager: bool = False,
+       num_lookahead_slots: int = 0,
        delay_factor: float = 0.0,
        enable_chunked_prefill: bool = False,
    ) -> None:

@@ -554,9 +559,11 @@ class SchedulerConfig:
            self.max_num_batched_tokens = max(max_model_len, 2048)
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
-       self.delay_factor = delay_factor
        self.use_v2_block_manager = use_v2_block_manager
+       self.num_lookahead_slots = num_lookahead_slots
+       self.delay_factor = delay_factor
        self.chunked_prefill_enabled = enable_chunked_prefill

        self._verify_args()

    def _verify_args(self) -> None:

@@ -568,12 +575,19 @@ class SchedulerConfig:
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")

+       if self.num_lookahead_slots < 0:
+           raise ValueError(
+               "num_lookahead_slots "
+               f"({self.num_lookahead_slots}) must be greater than or "
+               "equal to 0.")


class DeviceConfig:
@@ -85,7 +85,9 @@ class BlockTable:
                                                           device=device)
        self._num_full_slots = len(token_ids)

-   def append_token_ids(self, token_ids: List[int]) -> None:
+   def append_token_ids(self,
+                        token_ids: List[int],
+                        num_lookahead_slots: int = 0) -> None:
        """Appends a sequence of token IDs to the existing blocks in the
        BlockTable.

@@ -102,14 +104,13 @@ class BlockTable:
            token_ids (List[int]): The sequence of token IDs to be appended.
        """
        assert self._is_allocated
        assert token_ids, "can't append empty token ids"

-       self.ensure_num_empty_slots(num_empty_slots=len(token_ids))
+       self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
+                                   num_lookahead_slots)

        blocks = self._blocks[self._num_full_slots // self._block_size:]
-       first_chunk_size = self._block_size - (self._num_full_slots %
-                                              self._block_size)
-       token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
-           token_ids[first_chunk_size:], self._block_size)
+       token_blocks = self._chunk_token_blocks_for_append(token_ids)

        for block, token_block in zip(blocks, token_blocks):
            block.append_token_ids(token_block)

@@ -195,6 +196,25 @@ class BlockTable:
        assert self._is_allocated
        return [block.block_id for block in self._blocks]

+   def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
+       """Get the "unseen" tokens in the sequence.
+
+       Unseen tokens are tokens in the sequence corresponding to this block
+       table, but are not yet appended to this block table.
+
+       Args:
+           sequence_token_ids (List[int]): The list of token ids in the
+               sequence.
+
+       Returns:
+           List[int]: The postfix of sequence_token_ids that has not yet been
+               appended to the block table.
+       """
+
+       # Since the block table is append-only, the unseen token ids are the
+       # ones after the appended ones.
+       return sequence_token_ids[self.num_full_slots:]
+
    def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
                                       token_ids: List[int],
                                       device: Device) -> List[Block]:

@@ -243,3 +263,29 @@ class BlockTable:
            int: The total number of tokens currently stored in the BlockTable.
        """
        return self._num_full_slots

+   def get_num_blocks_touched_by_append_slots(
+           self, token_ids: List[int], num_lookahead_slots: int) -> int:
+       """Determine how many blocks will be "touched" by appending the token
+       ids.
+
+       This is required for the scheduler to determine whether a sequence can
+       continue generation, or if it must be preempted.
+       """
+
+       all_token_ids = token_ids + [-1] * num_lookahead_slots
+       token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
+       return len(token_blocks)
+
+   def _chunk_token_blocks_for_append(
+           self, token_ids: List[int]) -> List[List[int]]:
+       """Split the token ids into block-sized chunks so they can be easily
+       appended to blocks. The first such "token block" may have fewer token ids
+       than the block size, since the last allocated block may be partially
+       full.
+       """
+       first_chunk_size = self._block_size - (self._num_full_slots %
+                                              self._block_size)
+       token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
+           token_ids[first_chunk_size:], self._block_size)
+       return token_blocks
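A worked example of the chunking above: with block_size = 4 and 6 slots already full, the last block has 2 free slots, so the first chunk gets 2 token ids and the remainder is split into groups of 4. A minimal standalone re-implementation for illustration (not the BlockTable method itself):

from typing import List

def chunk_token_blocks_for_append(token_ids: List[int], num_full_slots: int,
                                  block_size: int) -> List[List[int]]:
    # First chunk tops off the partially full last block; the rest are full
    # block_size groups, mirroring the split done above.
    first_chunk_size = block_size - (num_full_slots % block_size)
    rest = token_ids[first_chunk_size:]
    chunks = [token_ids[:first_chunk_size]]
    chunks += [rest[i:i + block_size] for i in range(0, len(rest), block_size)]
    return [c for c in chunks if c]

# 3 new tokens plus 5 lookahead placeholders (-1), appended after 6 full slots
# with block_size=4, touch 3 blocks: [10, 11], [12, -1, -1, -1], [-1, -1].
token_ids = [10, 11, 12] + [-1] * 5
assert len(chunk_token_blocks_for_append(token_ids, 6, 4)) == 3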
@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from itertools import count, takewhile
from os.path import commonprefix
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set

from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor

@@ -292,7 +292,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            self.block_tables[seq.seq_id] = block_table.copy()

-   def can_append_slot(self, seq_group: SequenceGroup) -> bool:
+   def can_append_slots(self,
+                        seq_group: SequenceGroup,
+                        num_lookahead_slots: int = 0) -> bool:
+       assert (num_lookahead_slots == 0
+               ), "lookahead allocation not supported in BlockSpaceManagerV1"

        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

@@ -364,10 +369,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
        assert new_block.ref_count == 1
        return new_block

-   def append_slot(
+   def append_slots(
        self,
        seq: Sequence,
-   ) -> Optional[Tuple[int, int]]:
+       num_lookahead_slots: int = 0,
+   ) -> Dict[int, List[int]]:
        """Allocate a physical slot for a new token."""
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]

@@ -386,7 +392,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
            # Allocate a new physical block.
            new_block = self._allocate_last_physical_block(seq)
            block_table.append(new_block)
-           return None
+           return {}

        # We want to append the token to the last physical block.
        last_block = block_table[-1]

@@ -399,7 +405,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
            maybe_new_block = self._maybe_promote_last_block(
                seq, last_block)
            block_table[-1] = maybe_new_block
-           return None
+           return {}
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.

@@ -407,7 +413,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):

            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
-           return last_block.block_number, new_block.block_number
+           return {last_block.block_number: [new_block.block_number]}

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        # NOTE: fork does not allocate a new physical block.

@@ -433,7 +439,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
                blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

-   def can_swap_in(self, seq_group: SequenceGroup) -> bool:
+   def can_swap_in(self,
+                   seq_group: SequenceGroup,
+                   num_lookahead_slots: int = 0) -> bool:
+       assert (num_lookahead_slots == 0
+               ), "BlockSpaceManagerV1 does not support lookahead allocation"
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()

@@ -443,7 +453,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks

-   def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
+   def swap_in(self,
+               seq_group: SequenceGroup,
+               num_lookahead_slots: int = 0) -> Dict[int, int]:
+       assert (num_lookahead_slots == 0
+               ), "BlockSpaceManagerV1 does not support lookahead allocation"

        # CPU block -> GPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
@@ -1,5 +1,5 @@
"""A block manager that manages token blocks."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional

from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator

@@ -21,6 +21,24 @@ class BlockSpaceManagerV2(BlockSpaceManager):
    sliding-window are not feature complete. This class implements the design
    described in https://github.com/vllm-project/vllm/pull/3492.

+   Lookahead slots
+       The block manager has the notion of a "lookahead slot". These are slots
+       in the KV cache that are allocated for a sequence. Unlike the other
+       allocated slots, the content of these slots is undefined -- the worker
+       may use the memory allocations in any way.

+       In practice, a worker could use these lookahead slots to run multiple
+       forward passes for a single scheduler invocation. Each successive
+       forward pass would write KV activations to the corresponding lookahead
+       slot. This allows low inter-token latency use-cases, where the overhead
+       of continuous batching scheduling is amortized over >1 generated tokens.

+       Speculative decoding uses lookahead slots to store KV activations of
+       proposal tokens.

+       See https://github.com/vllm-project/vllm/pull/3250 for more information
+       on lookahead scheduling.

    Args:
        block_size (int): The size of each memory block.
        num_gpu_blocks (int): The number of memory blocks allocated on GPU.

@@ -116,35 +134,51 @@ class BlockSpaceManagerV2(BlockSpaceManager):
        for seq in waiting_seqs[1:]:
            self.block_tables[seq.seq_id] = block_table.fork()

-   def can_append_slot(self, seq_group: SequenceGroup) -> bool:
-       # Simple heuristic: If there is at least one free block
-       # for each sequence, we can append.
+   def can_append_slots(self, seq_group: SequenceGroup,
+                        num_lookahead_slots: int) -> bool:
+       """Determine if there is enough space in the GPU KV cache to continue
+       generation of the specified sequence group.
+
+       We use a worst-case heuristic: assume each touched block will require a
+       new allocation (either via CoW or new block). We can append slots if the
+       number of touched blocks is less than the number of free blocks.
+
+       "Lookahead slots" are slots that are allocated in addition to the slots
+       for known tokens. The contents of the lookahead slots are not defined.
+       This is used by speculative decoding when speculating future tokens.
+       """
+
+       num_touched_blocks = 0
+       for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+           block_table = self.block_tables[seq.seq_id]
+
+           num_touched_blocks += (
+               block_table.get_num_blocks_touched_by_append_slots(
+                   token_ids=block_table.get_unseen_token_ids(
+                       seq.get_token_ids()),
+                   num_lookahead_slots=num_lookahead_slots,
+               ))
+
        num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
            Device.GPU)
-       num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
-       return num_seqs <= num_free_gpu_blocks
+       return num_touched_blocks <= num_free_gpu_blocks

-   def append_slot(
+   def append_slots(
        self,
        seq: Sequence,
-   ) -> Optional[Tuple[int, int]]:
+       num_lookahead_slots: int,
+   ) -> Dict[int, List[int]]:

        block_table = self.block_tables[seq.seq_id]

-       # Get unseen token ids.
-       num_full_slots = block_table.num_full_slots
-       unseen_token_ids = seq.get_token_ids()[num_full_slots:]
-       assert unseen_token_ids
+       block_table.append_token_ids(
+           token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
+           num_lookahead_slots=num_lookahead_slots,
+       )

-       block_table.append_token_ids(unseen_token_ids)
-       # Return any copy-on-writes.
-       _ = self.block_allocator.clear_copy_on_writes()
-       # TODO extend append_slot interface to append_slots
-       # @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250
-       return None
+       # Return any new copy-on-writes.
+       new_cows = self.block_allocator.clear_copy_on_writes()
+       return new_cows

    def free(self, seq: Sequence) -> None:
        if seq.seq_id not in self.block_tables:

@@ -191,10 +225,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.fork()

-   def can_swap_in(self, seq_group: SequenceGroup) -> bool:
+   def can_swap_in(self, seq_group: SequenceGroup,
+                   num_lookahead_slots: int) -> bool:
        return False

-   def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
+   def swap_in(self, seq_group: SequenceGroup,
+               num_lookahead_slots: int) -> Dict[int, int]:
        raise NotImplementedError

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
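The worst-case check in can_append_slots assumes every touched block may need a fresh allocation, either a brand-new block or a CoW copy. Per sequence, the touched-block count reduces to simple slot arithmetic; a sketch of that count as a closed form (illustrative only, for a non-empty append; not the BlockTable method):

import math

def num_blocks_touched(num_full_slots: int, num_new_tokens: int,
                       num_lookahead_slots: int, block_size: int) -> int:
    # Blocks spanned by slots [num_full_slots, num_full_slots + new + lookahead).
    end_slot = num_full_slots + num_new_tokens + num_lookahead_slots
    return math.ceil(end_slot / block_size) - num_full_slots // block_size

# block_size=8, 13 slots already filled, 1 new token, 10 lookahead slots:
# slots 13..23 fall into blocks 1 and 2, so at most 2 new allocations or CoWs.
assert num_blocks_touched(13, 1, 10, 8) == 2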
@@ -1,6 +1,6 @@
import enum
from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List

from vllm.sequence import Sequence, SequenceGroup

@@ -44,14 +44,16 @@ class BlockSpaceManager(ABC):
        pass

    @abstractmethod
-   def can_append_slot(self, seq_group: SequenceGroup) -> bool:
+   def can_append_slots(self, seq_group: SequenceGroup,
+                        num_lookahead_slots: int) -> bool:
        pass

    @abstractmethod
-   def append_slot(
+   def append_slots(
        self,
        seq: Sequence,
-   ) -> Optional[Tuple[int, int]]:
+       num_lookahead_slots: int,
+   ) -> Dict[int, List[int]]:
        pass

    @abstractmethod

@@ -59,11 +61,13 @@ class BlockSpaceManager(ABC):
        pass

    @abstractmethod
-   def can_swap_in(self, seq_group: SequenceGroup) -> bool:
+   def can_swap_in(self, seq_group: SequenceGroup,
+                   num_lookahead_slots: int) -> bool:
        pass

    @abstractmethod
-   def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
+   def swap_in(self, seq_group: SequenceGroup,
+               num_lookahead_slots: int) -> Dict[int, int]:
        pass

    @abstractmethod
@@ -52,6 +52,7 @@ class SchedulerOutputs:
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        ignored_seq_groups: List[SequenceGroup],
+       num_lookahead_slots: int,
    ) -> None:
        """A list of sequence groups to be scheduled as a single batch.

@@ -86,6 +87,7 @@ class SchedulerOutputs:

        # Swap in and swap out should never happen at the same time.
        assert not (blocks_to_swap_in and blocks_to_swap_out)
+       self.num_lookahead_slots = num_lookahead_slots

        self.num_loras: int = len(self.lora_requests)
        if self.num_loras > 0:

@@ -309,6 +311,8 @@ class Scheduler:
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=ignored_seq_groups,
+           num_lookahead_slots=self._get_num_lookahead_slots(
+               is_prefill=True),
        )
        return scheduler_outputs

@@ -323,7 +327,7 @@ class Scheduler:
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.popleft()
-           while not self.block_manager.can_append_slot(seq_group):
+           while not self._can_append_slots(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop()

@@ -337,7 +341,7 @@ class Scheduler:
                    break
            else:
                # Append new slots to the sequence group.
-               self._append_slot(seq_group, blocks_to_copy)
+               self._append_slots(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

@@ -366,7 +370,7 @@ class Scheduler:
                    continue

            # If the sequence group cannot be swapped in, stop.
-           if not self.block_manager.can_swap_in(seq_group):
+           if not self._can_swap_in(seq_group):
                break

            # The total number of sequences in the RUNNING state should not

@@ -380,7 +384,7 @@ class Scheduler:
                curr_loras.add(lora_int_id)
            self.swapped.popleft()
            self._swap_in(seq_group, blocks_to_swap_in)
-           self._append_slot(seq_group, blocks_to_copy)
+           self._append_slots(seq_group, blocks_to_copy)
            num_curr_seqs += num_new_seqs
            self.running.append(seq_group)

@@ -405,9 +409,32 @@ class Scheduler:
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=[],
+           num_lookahead_slots=self._get_num_lookahead_slots(
+               is_prefill=False),
        )
        return scheduler_outputs

+   def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
+       """Determine whether or not we have enough space in the KV cache to
+       continue generation of the sequence group.
+       """
+       # Appending slots only occurs in decoding.
+       is_prefill = False
+
+       return self.block_manager.can_append_slots(
+           seq_group=seq_group,
+           num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
+       )
+
+   def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
+       # Swapping in is considered decode.
+       is_prefill = False
+
+       return self.block_manager.can_swap_in(
+           seq_group=seq_group,
+           num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
+       )
+
    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler

@@ -482,19 +509,30 @@ class Scheduler:
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            seq.status = SequenceStatus.RUNNING

-   def _append_slot(
+   def _append_slots(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
+       """Appends new slots to the sequences in the given sequence group.
+
+       Args:
+           seq_group (SequenceGroup): The sequence group containing the
+               sequences to append slots to.
+           blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
+               block indices to lists of destination block indices. This
+               dictionary is updated with the new source and destination block
+               indices for the appended slots.
+       """
+       num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
+
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-           ret = self.block_manager.append_slot(seq)
-           if ret is not None:
-               src_block, dst_block = ret
-               if src_block in blocks_to_copy:
-                   blocks_to_copy[src_block].append(dst_block)
-               else:
-                   blocks_to_copy[src_block] = [dst_block]
+           cows = self.block_manager.append_slots(seq, num_lookahead_slots)
+
+           for src, dests in cows.items():
+               if src not in blocks_to_copy:
+                   blocks_to_copy[src] = []
+               blocks_to_copy[src].extend(dests)

    def _preempt(
        self,

@@ -588,3 +626,16 @@ class Scheduler:
        else:
            passed_delay = True
        return passed_delay

+   def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
+       """The number of slots to allocate per sequence per step, beyond known
+       token ids. Speculative decoding uses these slots to store KV activations
+       of tokens which may or may not be accepted.
+
+       Speculative decoding does not yet support prefill, so we do not perform
+       lookahead allocation for prefill.
+       """
+       if is_prefill:
+           return 0
+
+       return self.scheduler_config.num_lookahead_slots
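On the scheduler side the change is mostly plumbing: append_slots now returns a mapping from a source block number to the destination blocks produced by copy-on-write, and _append_slots folds each per-sequence mapping into the batch-wide blocks_to_copy dict. A small sketch of that merge, equivalent to the if/else above, with made-up block numbers:

from typing import Dict, List

def merge_cows(blocks_to_copy: Dict[int, List[int]],
               cows: Dict[int, List[int]]) -> None:
    # Fold one sequence's copy-on-write results into the batch-wide mapping.
    for src, dests in cows.items():
        blocks_to_copy.setdefault(src, []).extend(dests)

blocks_to_copy: Dict[int, List[int]] = {3: [9]}
merge_cows(blocks_to_copy, {3: [12], 5: [13]})  # hypothetical append_slots output
assert blocks_to_copy == {3: [9, 12], 5: [13]}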
@@ -53,8 +53,8 @@ class EngineArgs:
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'
    ray_workers_use_nsight: bool = False
    forced_num_gpu_blocks: Optional[int] = None
+   num_lookahead_slots: int = 0

    # Related to Vision-language models such as llava
    image_input_type: Optional[str] = None

@@ -202,6 +202,14 @@ class EngineArgs:
        parser.add_argument('--use-v2-block-manager',
                            action='store_true',
                            help='Use BlockSpaceMangerV2')
+       parser.add_argument(
+           '--num-lookahead-slots',
+           type=int,
+           default=EngineArgs.num_lookahead_slots,
+           help='Experimental scheduling config necessary for '
+           'speculative decoding. This will be replaced by '
+           'speculative config in the future; it is present '
+           'to enable correctness tests until then.')

        parser.add_argument('--seed',
                            type=int,

@@ -406,6 +414,7 @@ class EngineArgs:
            self.max_num_seqs,
            model_config.max_model_len,
            self.use_v2_block_manager,
+           num_lookahead_slots=self.num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
        )
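Putting the new flag together with the e2e test above, enabling lookahead scheduling from user code is expected to look roughly like this; the kwargs mirror those exercised by the tests, and passing them through the LLM constructor (which forwards extra keyword arguments to EngineArgs) is an assumption of this sketch rather than a documented interface:

from vllm import LLM, SamplingParams

# Sketch: mirrors the kwargs used by the lookahead e2e test in this commit.
llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    use_v2_block_manager=True,   # lookahead is only supported by BlockSpaceManagerV2
    num_lookahead_slots=10,
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(max_tokens=128, temperature=0.0))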