[Speculative decoding 4/9] Lookahead scheduling for speculative decoding (#3250)

Author: Cade Daniel
Date: 2024-04-01 15:55:24 -07:00
Committed by: GitHub
Parent: ccb58b23e6
Commit: 93deb0b38f
13 changed files with 579 additions and 123 deletions

View File

@ -77,6 +77,159 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
# Use a large block size to trigger more copy-on-writes.
"block_size": 32,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify beam search equality with block manager v1 and v2.
This requires copy-on-writes; if the v1 and v2 output is the same, then
we have some confidence cow is working.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
use_beam_search=True,
best_of=2,
)
print('Getting token ids from block manager v1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids from block manager v2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids
assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# Our prompts will generate 128 tokens; since the prompts themselves are
# small, we don't need much KV space beyond 128.
"max_model_len": 160,
# skip cuda graph creation for fast test.
"enforce_eager": True,
# Lookahead scheduling only supported in v2 block manager.
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
"block_size": 16,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 8 = 128/block_size
"forced_num_gpu_blocks": 2 * (8 + 1),
},
{
"block_size": 8,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
"forced_num_gpu_blocks": 2 * (16 + 1),
}
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"num_lookahead_slots": 0,
}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[{
# We run one test with block_size < lookahead_slots, one test with
# block_size > lookahead_slots
"num_lookahead_slots": 10,
}])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size):
"""Verify vLLM produces the same output with greedy sampling, when lookahead
scheduling is used vs. not.
Lookahead scheduling is not expected to modify the output, as it simply
allocates empty slots ahead of the known token ids in a sliding fashion.
This test constrains the total number of blocks to force preemption. It also
varies the block size so that the lookahead size is less than and greater
than the block size.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
print('Getting token ids without lookahead scheduling')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids with lookahead scheduling')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids
assert baseline_token_ids == test_token_ids
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
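For readers following along outside pytest, here is a minimal sketch of how the kwargs combined by these parametrizations map onto a direct LLM construction. The keyword passthrough from LLM to EngineArgs is assumed and is not part of this diff; the values mirror the lookahead test above.

from vllm import LLM, SamplingParams

# Hedged sketch: kwargs are assumed to be forwarded to EngineArgs, mirroring
# the lookahead test parametrization above.
llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,            # skip CUDA graph capture for a quick run
    use_v2_block_manager=True,     # lookahead scheduling needs the v2 block manager
    num_lookahead_slots=10,        # extra KV slots per sequence per step
    block_size=16,
)
sampling_params = SamplingParams(max_tokens=128, ignore_eos=True, temperature=0.0)
outputs = llm.generate(["Hello, my name is"], sampling_params, use_tqdm=True)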

View File

@ -0,0 +1,103 @@
import pytest
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, SequenceStatus
from vllm.utils import chunk_list
from ..utils import create_seq_group
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
num_gpu_blocks: int, watermark: float):
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
watermark=watermark,
)
num_watermark_blocks = int(watermark * num_gpu_blocks)
num_output_blocks_per_seq = 1
# NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
# the current implementation assumes all seqs are new prompts / don't have
# different output lens.
num_output_blocks = num_output_blocks_per_seq
for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
seq_group = create_seq_group(
seq_prompt_len=block_size * num_prompt_blocks,
seq_output_lens=[
block_size * num_output_blocks_per_seq
for _ in range(num_seqs_per_group)
],
)
assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
can_allocate_result = block_manager.can_allocate(seq_group)
num_required_blocks = num_prompt_blocks + num_output_blocks
if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
assert can_allocate_result == AllocStatus.NEVER
elif num_gpu_blocks >= num_required_blocks:
assert can_allocate_result == AllocStatus.OK
else:
assert can_allocate_result == AllocStatus.LATER
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("prompt_len", [1, 7, 8])
@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129])
@pytest.mark.parametrize("num_lookahead_slots", [0, 10])
def test_append_slots(block_size, prompt_len, num_slots_to_append,
num_lookahead_slots):
"""Verify append_slots consumes the correct number of blocks from the block
table.
"""
num_gpu_blocks = 1024
watermark = 0.1
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
watermark=watermark,
)
seq_group = create_seq_group(
seq_prompt_len=prompt_len,
seq_output_lens=[0],
)
# Allocate seq
assert block_manager.can_allocate(seq_group)
block_manager.allocate(seq_group)
# Set seq status to RUNNING
seq = seq_group.get_seqs()[0]
seq.status = SequenceStatus.RUNNING
# Append tokens to the sequence.
for token_id in range(num_slots_to_append):
seq.append_token_id(token_id, {token_id: Logprob(0.0)})
# Append slots for new tokens and lookahead slots.
free_blocks_before_append = block_manager.get_num_free_gpu_blocks()
block_manager.append_slots(seq, num_lookahead_slots)
num_consumed_blocks = (free_blocks_before_append -
block_manager.get_num_free_gpu_blocks())
# Expect the consumed blocks to equal the new blocks required to hold the
# appended tokens plus lookahead slots.
expected_consumed_blocks = len(
chunk_list(
list(
range(prompt_len + num_slots_to_append + num_lookahead_slots)),
block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
assert num_consumed_blocks == expected_consumed_blocks
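The expected_consumed_blocks expression above reduces to a ceiling-division difference; a small worked sketch, equivalent for the positive lengths used in the parametrization:

import math

def expected_consumed_blocks(prompt_len, num_appended, num_lookahead, block_size):
    # Blocks needed for all slots minus blocks already needed for the prompt.
    total_slots = prompt_len + num_appended + num_lookahead
    return math.ceil(total_slots / block_size) - math.ceil(prompt_len / block_size)

# e.g. block_size=8, prompt_len=7, 8 appended tokens, 10 lookahead slots:
# ceil(25 / 8) - ceil(7 / 8) = 4 - 1 = 3 newly consumed blocks.
assert expected_consumed_blocks(7, 8, 10, 8) == 3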

View File

@ -1,50 +0,0 @@
import pytest
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.interfaces import AllocStatus
from ..utils import create_seq_group
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
num_gpu_blocks: int, watermark: float):
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
watermark=watermark,
)
num_watermark_blocks = int(watermark * num_gpu_blocks)
num_output_blocks_per_seq = 1
# NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
# the current implementation assumes all seqs are new prompts / don't have
# different output lens.
num_output_blocks = num_output_blocks_per_seq
for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
seq_group = create_seq_group(
seq_prompt_lens=block_size * num_prompt_blocks,
seq_output_lens=[
block_size * num_output_blocks_per_seq
for _ in range(num_seqs_per_group)
],
)
assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
can_allocate_result = block_manager.can_allocate(seq_group)
num_required_blocks = num_prompt_blocks + num_output_blocks
if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
assert can_allocate_result == AllocStatus.NEVER
elif num_gpu_blocks >= num_required_blocks:
assert can_allocate_result == AllocStatus.OK
else:
assert can_allocate_result == AllocStatus.LATER

View File

@ -498,3 +498,78 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int,
# After free, expect all blocks to be freed.
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("num_new_tokens", [1, 16, 129])
@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int,
num_new_tokens: int,
num_lookahead_slots: int,
allocator_type: str):
"""Verify correct calculation of get_num_blocks_touched_by_append_slots.
This is done by using copy-on-write, which requires any modified block to
be copied before write if the refcount > 1. We set the refcount>1 by forking
a sequence, then measure the free blocks before and after an append. If the
number of consumed blocks equals what `get_num_blocks_touched_by_append_
slots` returns, then the calculation is correct.
"""
num_gpu_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
block_size=block_size,
)
token_ids = list(range(sequence_len))
token_ids_to_append = list(range(num_new_tokens))
block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)
block_table.allocate(token_ids=token_ids, device=Device.GPU)
# Add lookahead before fork so both sequences have the same lookahead
# blocks.
block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots)
# Fork sequence so that every block has refcount > 1.
_ = block_table.fork()
# Determine how many blocks should be touched.
expected_num_touched_blocks = (
block_table.get_num_blocks_touched_by_append_slots(
token_ids=token_ids_to_append,
num_lookahead_slots=num_lookahead_slots))
# Measure how many blocks are touched by measuring num_free_blocks before
# and after the append.
#
# We expect append_token_ids to CoW all mutated blocks that have refcount>1.
num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU)
block_table.append_token_ids(token_ids_to_append, num_lookahead_slots)
num_consumed_blocks = (num_free_blocks_before_append -
allocator.get_num_free_blocks(Device.GPU))
# TODO(cade) ensure equality when num_lookahead_slots > 0.
# The reason we allow <= rather than == is that lookahead blocks are not
# copied eagerly; they are copied on first write. This will cause issues for
# beam search + speculative decoding. This is acceptable for now, as
# combining the two is a large effort. To fix this, we can ensure single-
# sequence ownership of lookahead blocks by appending empty slots to each
# block, which will trigger the CoW.
#
# Until then, we accept that the number of consumed blocks is <= the
# expected number when appending with lookahead.
if num_lookahead_slots > 0:
assert num_consumed_blocks <= expected_num_touched_blocks
else:
assert num_consumed_blocks == expected_num_touched_blocks

View File

@ -103,9 +103,9 @@ def test_append_slot_single_seq():
block_manager.allocate(seq_group)
# Nothing to append. Sequence has no new logical blocks.
assert block_manager.can_append_slot(seq_group)
assert block_manager.can_append_slots(seq_group)
before_blocks = block_manager.get_num_free_gpu_blocks()
assert not block_manager.append_slot(prompt)
assert not block_manager.append_slots(prompt)
after_blocks = block_manager.get_num_free_gpu_blocks()
assert before_blocks == after_blocks
@ -114,9 +114,9 @@ def test_append_slot_single_seq():
token_id = i + 5
prompt.append_token_id(token_id, {token_id: Logprob(0.0)})
assert block_manager.can_append_slot(seq_group)
assert block_manager.can_append_slots(seq_group)
before_blocks = block_manager.get_num_free_gpu_blocks()
assert not block_manager.append_slot(prompt)
assert not block_manager.append_slots(prompt)
after_blocks = block_manager.get_num_free_gpu_blocks()
assert before_blocks - after_blocks == 1
@ -150,13 +150,13 @@ def test_append_slot_cow():
child.append_token_id(token_id, {token_id: Logprob(0.0)})
block_manager.fork(prompt, child)
assert block_manager.can_append_slot(seq_group)
assert block_manager.can_append_slots(seq_group)
before_blocks = block_manager.get_num_free_gpu_blocks()
maybe_src_dst_block = block_manager.append_slot(child)
assert maybe_src_dst_block is not None
src_block, dst_block = maybe_src_dst_block
assert src_block != dst_block
cows = block_manager.append_slots(child)
assert cows
for src_block, dst_blocks in cows.items():
assert src_block not in dst_blocks
after_blocks = block_manager.get_num_free_gpu_blocks()
assert before_blocks - after_blocks == 1
@ -184,7 +184,7 @@ def test_fork():
token_id = 4
# Append token to child. Block is shared so copy on write occurs.
child.append_token_id(token_id, {token_id: Logprob(0.0)})
block_manager.append_slot(child)
block_manager.append_slots(child)
assert block_manager.get_block_table(
prompt) != block_manager.get_block_table(child)
@ -325,7 +325,7 @@ def test_sliding_window_multi_seq():
token_id = 4
# Append token to child. Block is shared so copy on write occurs.
child.append_token_id(token_id, {token_id: Logprob(0.0)})
block_manager.append_slot(child)
block_manager.append_slots(child)
# assert the number of blocks allocated is correct
# we will use now one block more. Each seq will use 2 blocks,
@ -335,7 +335,7 @@ def test_sliding_window_multi_seq():
token_id = 5
parent.append_token_id(token_id, {token_id: Logprob(0.0)})
block_manager.append_slot(parent)
block_manager.append_slots(parent)
# assert the number of blocks allocated is correct
# no change, because both sequences are still just sharing one block

View File

@ -24,7 +24,7 @@ def create_dummy_prompt(
def create_seq_group(
seq_prompt_lens=1024,
seq_prompt_len=1024,
seq_output_lens=(128, ),
request_id='0',
seq_id_start=0,
@ -32,7 +32,7 @@ def create_seq_group(
assert len(seq_output_lens) > 0
prompt_token_ids = [0] * seq_prompt_lens
prompt_token_ids = [0] * seq_prompt_len
seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):

View File

@ -530,9 +530,13 @@ class SchedulerConfig:
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
num_lookahead_slots: The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative
decoding to store KV activations of tokens which may or may not be
accepted.
delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
"""
@ -543,6 +547,7 @@ class SchedulerConfig:
max_num_seqs: int,
max_model_len: int,
use_v2_block_manager: bool = False,
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
) -> None:
@ -554,9 +559,11 @@ class SchedulerConfig:
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.delay_factor = delay_factor
self.use_v2_block_manager = use_v2_block_manager
self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self._verify_args()
def _verify_args(self) -> None:
@ -568,12 +575,19 @@ class SchedulerConfig:
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len.")
if self.max_num_batched_tokens < self.max_num_seqs:
raise ValueError(
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
f"({self.max_num_seqs}).")
if self.num_lookahead_slots < 0:
raise ValueError(
"num_lookahead_slots "
f"({self.num_lookahead_slots}) must be greater than or "
"equal to 0.")
class DeviceConfig:
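A hedged sketch of the new SchedulerConfig argument and its validation; the leading max_num_batched_tokens parameter is assumed from surrounding context and is not shown in this hunk.

from vllm.config import SchedulerConfig

cfg = SchedulerConfig(
    max_num_batched_tokens=2048,
    max_num_seqs=64,
    max_model_len=2048,
    use_v2_block_manager=True,
    num_lookahead_slots=5,     # extra KV slots per sequence per step
)
assert cfg.num_lookahead_slots == 5

# Negative values are rejected by the new _verify_args check.
try:
    SchedulerConfig(
        max_num_batched_tokens=2048,
        max_num_seqs=64,
        max_model_len=2048,
        num_lookahead_slots=-1,
    )
except ValueError as e:
    print(e)  # num_lookahead_slots (-1) must be greater than or equal to 0.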

View File

@ -85,7 +85,9 @@ class BlockTable:
device=device)
self._num_full_slots = len(token_ids)
def append_token_ids(self, token_ids: List[int]) -> None:
def append_token_ids(self,
token_ids: List[int],
num_lookahead_slots: int = 0) -> None:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
@ -102,14 +104,13 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
assert token_ids, "can't append empty token ids"
self.ensure_num_empty_slots(num_empty_slots=len(token_ids))
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
blocks = self._blocks[self._num_full_slots // self._block_size:]
first_chunk_size = self._block_size - (self._num_full_slots %
self._block_size)
token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
token_ids[first_chunk_size:], self._block_size)
token_blocks = self._chunk_token_blocks_for_append(token_ids)
for block, token_block in zip(blocks, token_blocks):
block.append_token_ids(token_block)
@ -195,6 +196,25 @@ class BlockTable:
assert self._is_allocated
return [block.block_id for block in self._blocks]
def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
"""Get the number of "unseen" tokens in the sequence.
Unseen tokens are tokens in the sequence corresponding to this block
table, but are not yet appended to this block table.
Args:
sequence_token_ids (List[int]): The list of token ids in the
sequence.
Returns:
List[int]: The postfix of sequence_token_ids that has not yet been
appended to the block table.
"""
# Since the block table is append-only, the unseen token ids are the
# ones after the appended ones.
return sequence_token_ids[self.num_full_slots:]
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> List[Block]:
@ -243,3 +263,29 @@ class BlockTable:
int: The total number of tokens currently stored in the BlockTable.
"""
return self._num_full_slots
def get_num_blocks_touched_by_append_slots(
self, token_ids: List[int], num_lookahead_slots: int) -> int:
"""Determine how many blocks will be "touched" by appending the token
ids.
This is required for the scheduler to determine whether a sequence can
continue generation, or if it must be preempted.
"""
all_token_ids = token_ids + [-1] * num_lookahead_slots
token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
return len(token_blocks)
def _chunk_token_blocks_for_append(
self, token_ids: List[int]) -> List[List[int]]:
"""Split the token ids into block-sized chunks so they can be easily
appended to blocks. The first such "token block" may have less token ids
than the block size, since the last allocated block may be partially
full.
"""
first_chunk_size = self._block_size - (self._num_full_slots %
self._block_size)
token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
token_ids[first_chunk_size:], self._block_size)
return token_blocks
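To make the block-count math concrete, here is a hedged standalone re-derivation of get_num_blocks_touched_by_append_slots (illustration only; the BlockTable method above remains the source of truth):

from vllm.utils import chunk_list

def blocks_touched(num_full_slots, token_ids, num_lookahead_slots, block_size):
    # Lookahead slots are counted with placeholder ids, exactly as above.
    all_token_ids = token_ids + [-1] * num_lookahead_slots
    # The first chunk only fills the tail of the current partially-full block.
    first_chunk_size = block_size - (num_full_slots % block_size)
    token_blocks = [all_token_ids[:first_chunk_size]] + chunk_list(
        all_token_ids[first_chunk_size:], block_size)
    return len(token_blocks)

# e.g. 12 slots already full, block_size=8, appending 5 tokens + 7 lookahead
# slots: the first 4 slots finish the current block, the remaining 8 fill
# exactly one more, so 2 blocks are touched.
print(blocks_touched(12, list(range(5)), 7, 8))  # 2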

View File

@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
@ -292,7 +292,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
def can_append_slots(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> bool:
assert (num_lookahead_slots == 0
), "lookahead allocation not supported in BlockSpaceManagerV1"
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
@ -364,10 +369,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
assert new_block.ref_count == 1
return new_block
def append_slot(
def append_slots(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
num_lookahead_slots: int = 0,
) -> Dict[int, List[int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
block_table = self.block_tables[seq.seq_id]
@ -386,7 +392,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Allocate a new physical block.
new_block = self._allocate_last_physical_block(seq)
block_table.append(new_block)
return None
return {}
# We want to append the token to the last physical block.
last_block = block_table[-1]
@ -399,7 +405,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
maybe_new_block = self._maybe_promote_last_block(
seq, last_block)
block_table[-1] = maybe_new_block
return None
return {}
else:
# The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens.
@ -407,7 +413,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
block_table[-1] = new_block
self.gpu_allocator.free(last_block)
return last_block.block_number, new_block.block_number
return {last_block.block_number: [new_block.block_number]}
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
@ -433,7 +439,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
blocks.update(self.block_tables[seq.seq_id])
return list(blocks)
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
def can_swap_in(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> bool:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
blocks = self._get_physical_blocks(seq_group)
num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
num_free_blocks = self.gpu_allocator.get_num_free_blocks()
@ -443,7 +453,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
num_required_blocks = len(blocks) + num_swapped_seqs
return num_free_blocks - num_required_blocks >= self.watermark_blocks
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
def swap_in(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> Dict[int, int]:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
# CPU block -> GPU block.
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):

View File

@ -1,5 +1,5 @@
"""A block manager that manages token blocks."""
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
@ -21,6 +21,24 @@ class BlockSpaceManagerV2(BlockSpaceManager):
sliding-window are not feature complete. This class implements the design
described in https://github.com/vllm-project/vllm/pull/3492.
Lookahead slots
The block manager has the notion of a "lookahead slot". These are slots
in the KV cache that are allocated for a sequence. Unlike the other
allocated slots, the content of these slots is undefined -- the worker
may use the memory allocations in any way.
In practice, a worker could use these lookahead slots to run multiple
forward passes for a single scheduler invocation. Each successive
forward pass would write KV activations to the corresponding lookahead
slot. This enables low inter-token-latency use cases, where the overhead
of continuous batching scheduling is amortized over more than one
generated token.
Speculative decoding uses lookahead slots to store KV activations of
proposal tokens.
See https://github.com/vllm-project/vllm/pull/3250 for more information
on lookahead scheduling.
Args:
block_size (int): The size of each memory block.
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
@ -116,35 +134,51 @@ class BlockSpaceManagerV2(BlockSpaceManager):
for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork()
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
"""Determine if there is enough space in the GPU KV cache to continue
generation of the specified sequence group.
We use a worst-case heuristic: assume each touched block will require a
new allocation (either via CoW or a new block). We can append slots if the
number of touched blocks does not exceed the number of free blocks.
"Lookahead slots" are slots that are allocated in addition to the slots
for known tokens. The contents of the lookahead slots are not defined.
This is used by speculative decoding when speculating future tokens.
"""
num_touched_blocks = 0
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
block_table = self.block_tables[seq.seq_id]
num_touched_blocks += (
block_table.get_num_blocks_touched_by_append_slots(
token_ids=block_table.get_unseen_token_ids(
seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots,
))
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
Device.GPU)
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks
return num_touched_blocks <= num_free_gpu_blocks
def append_slot(
def append_slots(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
num_lookahead_slots: int,
) -> Dict[int, List[int]]:
block_table = self.block_tables[seq.seq_id]
# Get unseen token ids.
num_full_slots = block_table.num_full_slots
unseen_token_ids = seq.get_token_ids()[num_full_slots:]
assert unseen_token_ids
block_table.append_token_ids(
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots,
)
block_table.append_token_ids(unseen_token_ids)
# Return any copy-on-writes.
_ = self.block_allocator.clear_copy_on_writes()
# TODO extend append_slot interface to append_slots
# @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250
return None
# Return any new copy-on-writes.
new_cows = self.block_allocator.clear_copy_on_writes()
return new_cows
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
@ -191,10 +225,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork()
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
return False
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> Dict[int, int]:
raise NotImplementedError
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
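A hedged end-to-end sketch of the worst-case heuristic, reusing create_seq_group from the test utilities (import path assumed); it shows the same sequence group being accepted or denied based purely on the requested lookahead:

from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.sequence import SequenceStatus
from tests.core.utils import create_seq_group  # helper path assumed

block_manager = BlockSpaceManagerV2(block_size=8,
                                    num_gpu_blocks=4,
                                    num_cpu_blocks=0,
                                    watermark=0.0)
seq_group = create_seq_group(seq_prompt_len=8, seq_output_lens=[0])
block_manager.allocate(seq_group)          # prompt occupies 1 of 4 blocks
seq_group.get_seqs()[0].status = SequenceStatus.RUNNING

# 8 lookahead slots fit in a single fresh block; 3 blocks are free.
print(block_manager.can_append_slots(seq_group, num_lookahead_slots=8))   # True
# 64 lookahead slots would touch 8 blocks, but only 3 are free.
print(block_manager.can_append_slots(seq_group, num_lookahead_slots=64))  # False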

View File

@ -1,6 +1,6 @@
import enum
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from typing import Dict, List
from vllm.sequence import Sequence, SequenceGroup
@ -44,14 +44,16 @@ class BlockSpaceManager(ABC):
pass
@abstractmethod
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
def can_append_slots(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
pass
@abstractmethod
def append_slot(
def append_slots(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
num_lookahead_slots: int,
) -> Dict[int, List[int]]:
pass
@abstractmethod
@ -59,11 +61,13 @@ class BlockSpaceManager(ABC):
pass
@abstractmethod
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> bool:
pass
@abstractmethod
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> Dict[int, int]:
pass
@abstractmethod

View File

@ -52,6 +52,7 @@ class SchedulerOutputs:
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup],
num_lookahead_slots: int,
) -> None:
"""A list of sequence groups to be scheduled as a single batch.
@ -86,6 +87,7 @@ class SchedulerOutputs:
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.num_lookahead_slots = num_lookahead_slots
self.num_loras: int = len(self.lora_requests)
if self.num_loras > 0:
@ -309,6 +311,8 @@ class Scheduler:
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True),
)
return scheduler_outputs
@ -323,7 +327,7 @@ class Scheduler:
preempted: List[SequenceGroup] = []
while self.running:
seq_group = self.running.popleft()
while not self.block_manager.can_append_slot(seq_group):
while not self._can_append_slots(seq_group):
if self.running:
# Preempt the lowest-priority sequence groups.
victim_seq_group = self.running.pop()
@ -337,7 +341,7 @@ class Scheduler:
break
else:
# Append new slots to the sequence group.
self._append_slot(seq_group, blocks_to_copy)
self._append_slots(seq_group, blocks_to_copy)
running.append(seq_group)
self.running = running
@ -366,7 +370,7 @@ class Scheduler:
continue
# If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group):
if not self._can_swap_in(seq_group):
break
# The total number of sequences in the RUNNING state should not
@ -380,7 +384,7 @@ class Scheduler:
curr_loras.add(lora_int_id)
self.swapped.popleft()
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy)
self._append_slots(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs
self.running.append(seq_group)
@ -405,9 +409,32 @@ class Scheduler:
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=[],
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=False),
)
return scheduler_outputs
def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
"""Determine whether or not we have enough space in the KV cache to
continue generation of the sequence group.
"""
# Appending slots only occurs in decoding.
is_prefill = False
return self.block_manager.can_append_slots(
seq_group=seq_group,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
)
def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
# Swapping in is considered decode.
is_prefill = False
return self.block_manager.can_swap_in(
seq_group=seq_group,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
)
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
@ -482,19 +509,30 @@ class Scheduler:
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING
def _append_slot(
def _append_slots(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
) -> None:
"""Appends new slots to the sequences in the given sequence group.
Args:
seq_group (SequenceGroup): The sequence group containing the
sequences to append slots to.
blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
block indices to lists of destination block indices. This
dictionary is updated with the new source and destination block
indices for the appended slots.
"""
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
ret = self.block_manager.append_slot(seq)
if ret is not None:
src_block, dst_block = ret
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]
cows = self.block_manager.append_slots(seq, num_lookahead_slots)
for src, dests in cows.items():
if src not in blocks_to_copy:
blocks_to_copy[src] = []
blocks_to_copy[src].extend(dests)
def _preempt(
self,
@ -588,3 +626,16 @@ class Scheduler:
else:
passed_delay = True
return passed_delay
def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
"""The number of slots to allocate per sequence per step, beyond known
token ids. Speculative decoding uses these slots to store KV activations
of tokens which may or may not be accepted.
Speculative decoding does not yet support prefill, so we do not perform
lookahead allocation for prefill.
"""
if is_prefill:
return 0
return self.scheduler_config.num_lookahead_slots
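The CoW aggregation in _append_slots is easiest to see with a tiny standalone illustration (plain dict manipulation, independent of vLLM):

# Two sequences in the same group report their copy-on-writes; the scheduler
# merges them into the batch-wide blocks_to_copy mapping.
blocks_to_copy = {}
for cows in ({3: [7]}, {3: [9], 5: [11]}):
    for src, dests in cows.items():
        if src not in blocks_to_copy:
            blocks_to_copy[src] = []
        blocks_to_copy[src].extend(dests)
print(blocks_to_copy)  # {3: [7, 9], 5: [11]}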

View File

@ -53,8 +53,8 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
forced_num_gpu_blocks: Optional[int] = None
num_lookahead_slots: int = 0
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
@ -202,6 +202,14 @@ class EngineArgs:
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
parser.add_argument(
'--num-lookahead-slots',
type=int,
default=EngineArgs.num_lookahead_slots,
help='Experimental scheduling config necessary for '
'speculative decoding. This will be replaced by '
'speculative config in the future; it is present '
'to enable correctness tests until then.')
parser.add_argument('--seed',
type=int,
@ -406,6 +414,7 @@ class EngineArgs:
self.max_num_seqs,
model_config.max_model_len,
self.use_v2_block_manager,
num_lookahead_slots=self.num_lookahead_slots,
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
)
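Finally, a hedged sketch of the new flag flowing through the CLI helpers; EngineArgs.add_cli_args and EngineArgs.from_cli_args are existing vLLM helpers, and their use here is illustrative rather than part of this diff:

import argparse
from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--use-v2-block-manager",
    "--num-lookahead-slots", "8",
])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.num_lookahead_slots)  # 8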