Compare commits

...

4 Commits

Author SHA1 Message Date
d5bf492f16 Merge branch 'main' into optimize-prefix-caching-scheduling 2024-06-04 00:20:15 +00:00
8c7bab79f5 simplify code 2024-06-03 03:36:38 +00:00
1936d7bab0 format 2024-06-02 00:02:54 +00:00
996cf2de5c Fix hashing logic for non-full blocks 2024-06-02 00:01:30 +00:00
4 changed files with 47 additions and 49 deletions

vllm/block.py

@@ -1,5 +1,5 @@
"""Token blocks."""
from typing import List
from typing import List, Optional
from vllm.utils import Device
@@ -25,6 +25,7 @@ class LogicalTokenBlock:
self.token_ids = [_BLANK_TOKEN_ID] * block_size
self.num_tokens = 0
self.block_hash: Optional[int] = None
def is_empty(self) -> bool:
return self.num_tokens == 0

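The new `block_hash` field gives every logical block a slot to memoize its content hash once the block is full; the rewritten hashing logic in the last file of this diff relies on it together with the block's fullness check. A minimal, self-contained sketch of the block shape this implies (the constructor is simplified, the value of `_BLANK_TOKEN_ID` is assumed, and `is_full` is inferred from its use in `hash_of_block` later in this diff):

from typing import List, Optional

_BLANK_TOKEN_ID = -1  # sentinel for unfilled slots (value assumed for this sketch)


class LogicalTokenBlock:
    """Fixed-size block of token IDs belonging to one sequence (sketch only)."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.token_ids: List[int] = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0
        # New in this change: lazily computed content hash. None means the
        # block has not been hashed yet (e.g. because it is not yet full).
        self.block_hash: Optional[int] = None

    def is_empty(self) -> bool:
        return self.num_tokens == 0

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size
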
vllm/core/block_manager_v1.py

@@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self.cross_block_tables: Dict[str, BlockTable] = {}
def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
return 0 if seq is None \
else len(seq.logical_token_blocks)
return 0 if seq is None else len(seq.logical_token_blocks)
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
@@ -275,8 +274,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
cross_num_required_blocks = self._get_seq_num_required_blocks(
seq_group.get_encoder_seq())
num_required_blocks = self_num_required_blocks + \
cross_num_required_blocks
num_required_blocks = (self_num_required_blocks +
cross_num_required_blocks)
if self.block_sliding_window is not None:
@@ -293,9 +292,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
else:
return AllocStatus.LATER
def _allocate_sequence(self, \
seq: Sequence, \
ref_count: int, \
def _allocate_sequence(self,
seq: Sequence,
ref_count: int,
is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks)
@@ -328,10 +327,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# NOTE: Here we assume that all sequences in the group have the same
# decoder prompt.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
block_table: BlockTable = \
self._allocate_sequence(seq,
seq_group.num_seqs(),
is_encoder_decoder)
block_table: BlockTable = self._allocate_sequence(
seq, seq_group.num_seqs(), is_encoder_decoder)
# Assign the self-attention block tables for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
@@ -368,6 +365,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Compute a new hash for the block so that it can be shared by other
# Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
assert new_hash is not None, "Last block is not full."
# if new_hash is already in the cached table, then free last_block
# and return the cached version
@@ -406,9 +404,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# content hash.
if not self.enable_caching:
return self.gpu_allocator.allocate()
block_hash: Optional[int] = None
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1)
@@ -549,18 +545,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# dict is efficient in lookup `if cpu_block in mapping`
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.cpu_allocator,
self.gpu_allocator,
mapping)
self.block_tables[seq.seq_id] = self._swap_block_table(
self.block_tables[seq.seq_id], self.cpu_allocator,
self.gpu_allocator, mapping)
if seq_group.is_encoder_decoder():
self.cross_block_tables[request_id] = \
self._swap_block_table(self.cross_block_tables[request_id],
self.cpu_allocator,
self.gpu_allocator,
mapping)
self.cross_block_tables[request_id] = self._swap_block_table(
self.cross_block_tables[request_id], self.cpu_allocator,
self.gpu_allocator, mapping)
return [(cpu_block.block_number, gpu_block.block_number)
for cpu_block, gpu_block in mapping.items()]
@@ -576,18 +568,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# dict is efficient in lookup `if gpu_block in mapping`
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.gpu_allocator,
self.cpu_allocator,
mapping)
self.block_tables[seq.seq_id] = self._swap_block_table(
self.block_tables[seq.seq_id], self.gpu_allocator,
self.cpu_allocator, mapping)
if seq_group.is_encoder_decoder():
self.cross_block_tables[request_id] = \
self._swap_block_table(self.cross_block_tables[request_id],
self.gpu_allocator,
self.cpu_allocator,
mapping)
self.cross_block_tables[request_id] = self._swap_block_table(
self.cross_block_tables[request_id], self.gpu_allocator,
self.cpu_allocator, mapping)
return [(cpu_block.block_number, gpu_block.block_number)
for cpu_block, gpu_block in mapping.items()]

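The substantive changes in this file follow from the new `hash_of_block` contract: the hunk around old line 368 now asserts that a block being promoted for sharing really has a hash (`new_hash is not None`), and the allocation hunk around old line 406 simply forwards whatever `seq.hash_of_block(...)` returns to the GPU allocator instead of guarding with `_is_last_block_full`, so a `None` hash means "allocate a fresh block that is not entered into the prefix cache". A toy cached-allocator sketch of that convention (toy names, not the real `gpu_allocator` API):

from typing import Dict, Optional


class ToyPhysicalBlock:
    def __init__(self, block_number: int) -> None:
        self.block_number = block_number
        self.ref_count = 0


class ToyCachedAllocator:
    """Toy allocator: full blocks are shared by content hash, partial ones are not."""

    def __init__(self) -> None:
        self.cached: Dict[int, ToyPhysicalBlock] = {}
        self._next_number = 0

    def _fresh_block(self) -> ToyPhysicalBlock:
        block = ToyPhysicalBlock(self._next_number)
        self._next_number += 1
        return block

    def allocate(self, block_hash: Optional[int] = None) -> ToyPhysicalBlock:
        # block_hash is None for a not-yet-full block: never look it up in,
        # or insert it into, the content-hash cache.
        if block_hash is None:
            block = self._fresh_block()
        elif block_hash in self.cached:
            block = self.cached[block_hash]
        else:
            block = self.cached[block_hash] = self._fresh_block()
        block.ref_count += 1
        return block

Under this convention a partial block always gets a fresh physical block starting at `ref_count == 1`; the fullness decision itself now lives inside `hash_of_block`, so the block manager only forwards its result.
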
vllm/core/scheduler.py

@@ -388,7 +388,7 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining running queue (should be always 0) after
scheduling and SchedulerRunningOutputs.
@@ -655,11 +655,12 @@ class Scheduler:
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
waiting_seq = waiting_seqs[0]
num_new_tokens = self._get_num_new_tokens(seq_group,
SequenceStatus.WAITING,
enable_chunking, budget)
if not enable_chunking:
num_prompt_tokens = waiting_seqs[0].get_len()
num_prompt_tokens = waiting_seq.get_len()
assert num_new_tokens == num_prompt_tokens
prompt_limit = self._get_prompt_limit(seq_group)
@@ -667,8 +668,7 @@ class Scheduler:
logger.warning(
"Input prompt (%d tokens) is too long"
" and exceeds limit of %d", num_new_tokens, prompt_limit)
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
waiting_seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
waiting_queue.popleft()
continue
@@ -731,7 +731,7 @@ class Scheduler:
def _schedule_default(self) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
it batches as many prefill requests as possible. And it schedules
decodes. If there's a pressure on GPU memory, decode requests can
@@ -825,7 +825,7 @@ class Scheduler:
def _schedule_chunked_prefill(self):
"""Schedule queued requests.
Chunked prefill allows to chunk prefill requests, batch them together
with decode requests. This policy 1. schedule as many decoding requests
as possible. 2. schedule chunked prefill requests that are not

vllm/sequence.py

@@ -270,15 +270,24 @@ class Sequence:
return self.output_text[:-buffer_length] if truncate else (
self.output_text)
def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
# Compute the number of tokens in the sequence
def hash_of_block(self, logical_idx: int) -> Optional[int]:
"""Return the hash of the block if it is full."""
# TODO: The current hashing function is O(L^2). We should optimize
# this in the future.
num_tokens = self.num_hashed_tokens_of_block(logical_idx)
hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
return hash((hashed_tokens, self.lora_int_id))
assert logical_idx < len(self.logical_token_blocks), (
f"logical_idx={logical_idx} is out of range for "
f"logical_token_blocks={len(self.logical_token_blocks)}")
block = self.logical_token_blocks[logical_idx]
if block.block_hash is not None:
return block.block_hash
if not block.is_full():
return None
num_hashed_tokens = self.num_hashed_tokens_of_block(logical_idx)
hashed_tokens = self.data.get_prefix_token_ids(num_hashed_tokens)
block_hash = hash((hashed_tokens, self.lora_int_id))
# Cache the block hash for future use.
block.block_hash = block_hash
return block_hash
def num_hashed_tokens_of_block(self, logical_idx: int):
return logical_idx * self.block_size + self.block_size
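Taken together, the rewritten `hash_of_block` has a three-way contract: return the memoized hash if the block was hashed before, return `None` while the block is not yet full, and otherwise hash the full token prefix (all `(logical_idx + 1) * block_size` tokens from `num_hashed_tokens_of_block`, plus the LoRA id) and cache the result on the block. A small self-contained sketch of that contract, using toy stand-ins rather than the real `Sequence`/`SequenceData` classes:

from typing import List, Optional, Tuple


class ToyBlock:
    """Stand-in for LogicalTokenBlock: fullness state plus the cached hash."""

    def __init__(self, block_size: int, num_tokens: int) -> None:
        self.block_size = block_size
        self.num_tokens = num_tokens
        self.block_hash: Optional[int] = None

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size


def hash_of_block(blocks: List[ToyBlock], prefix_tokens: Tuple[int, ...],
                  logical_idx: int, lora_int_id: int = 0) -> Optional[int]:
    """Mirror of the new logic: memoized, and None until the block is full."""
    assert logical_idx < len(blocks)
    block = blocks[logical_idx]
    if block.block_hash is not None:   # hashed before: reuse the cached value
        return block.block_hash
    if not block.is_full():            # partially filled: not shareable yet
        return None
    # e.g. block_size=4, logical_idx=1 -> the first 8 prefix tokens are hashed
    num_hashed_tokens = (logical_idx + 1) * block.block_size
    block.block_hash = hash((prefix_tokens[:num_hashed_tokens], lora_int_id))
    return block.block_hash


if __name__ == "__main__":
    blocks = [ToyBlock(block_size=4, num_tokens=4),
              ToyBlock(block_size=4, num_tokens=2)]   # second block half full
    tokens = (11, 12, 13, 14, 15, 16)
    assert hash_of_block(blocks, tokens, 1) is None   # non-full block -> None
    assert isinstance(hash_of_block(blocks, tokens, 0), int)
    assert blocks[0].block_hash is not None           # result is memoized
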
@@ -614,7 +623,7 @@ class SequenceGroupMetadata:
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
model.
cross_block_table: Optional cross-attention block table associated