Compare commits

...

4 Commits

Author SHA1 Message Date
d5bf492f16 Merge branch 'main' into optimize-prefix-caching-scheduling 2024-06-04 00:20:15 +00:00
8c7bab79f5 simplify code 2024-06-03 03:36:38 +00:00
1936d7bab0 format 2024-06-02 00:02:54 +00:00
996cf2de5c Fix hashing logic for non-full blocks 2024-06-02 00:01:30 +00:00
4 changed files with 47 additions and 49 deletions

View File

@@ -1,5 +1,5 @@
 """Token blocks."""
-from typing import List
+from typing import List, Optional

 from vllm.utils import Device
@@ -25,6 +25,7 @@ class LogicalTokenBlock:
         self.token_ids = [_BLANK_TOKEN_ID] * block_size
         self.num_tokens = 0
+        self.block_hash: Optional[int] = None

     def is_empty(self) -> bool:
         return self.num_tokens == 0

View File

@@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         self.cross_block_tables: Dict[str, BlockTable] = {}

     def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
-        return 0 if seq is None \
-            else len(seq.logical_token_blocks)
+        return 0 if seq is None else len(seq.logical_token_blocks)

     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
@@ -275,8 +274,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
         cross_num_required_blocks = self._get_seq_num_required_blocks(
             seq_group.get_encoder_seq())
-        num_required_blocks = self_num_required_blocks + \
-            cross_num_required_blocks
+        num_required_blocks = (self_num_required_blocks +
+                               cross_num_required_blocks)

         if self.block_sliding_window is not None:
@@ -293,9 +292,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         else:
             return AllocStatus.LATER

-    def _allocate_sequence(self, \
-                           seq: Sequence, \
-                           ref_count: int, \
+    def _allocate_sequence(self,
+                           seq: Sequence,
+                           ref_count: int,
                            is_encoder_decoder: bool = True) -> BlockTable:
         # Allocate new physical token blocks that will store the prompt tokens.
         num_prompt_blocks = len(seq.logical_token_blocks)
@@ -328,10 +327,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # NOTE: Here we assume that all sequences in the group have the same
         # decoder prompt.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
-        block_table: BlockTable = \
-            self._allocate_sequence(seq,
-                                    seq_group.num_seqs(),
-                                    is_encoder_decoder)
+        block_table: BlockTable = self._allocate_sequence(
+            seq, seq_group.num_seqs(), is_encoder_decoder)

         # Assign the self-attention block tables for each sequence.
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
@@ -368,6 +365,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # Compute a new hash for the block so that it can be shared by other
         # Sequences
         new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        assert new_hash is not None, "Last block is not full."

         # if new_hash is already in the cached table, then free last_block
         # and return the cached version
@@ -406,9 +404,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # content hash.
         if not self.enable_caching:
             return self.gpu_allocator.allocate()
-        block_hash: Optional[int] = None
-        if (self._is_last_block_full(seq)):
-            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)

         num_hashed_tokens = seq.num_hashed_tokens_of_block(
             len(seq.logical_token_blocks) - 1)
@@ -549,18 +545,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # dict is efficient in lookup `if cpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
-            self.block_tables[seq.seq_id] = \
-                self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.cpu_allocator,
-                                       self.gpu_allocator,
-                                       mapping)
+            self.block_tables[seq.seq_id] = self._swap_block_table(
+                self.block_tables[seq.seq_id], self.cpu_allocator,
+                self.gpu_allocator, mapping)

         if seq_group.is_encoder_decoder():
-            self.cross_block_tables[request_id] = \
-                self._swap_block_table(self.cross_block_tables[request_id],
-                                       self.cpu_allocator,
-                                       self.gpu_allocator,
-                                       mapping)
+            self.cross_block_tables[request_id] = self._swap_block_table(
+                self.cross_block_tables[request_id], self.cpu_allocator,
+                self.gpu_allocator, mapping)

         return [(cpu_block.block_number, gpu_block.block_number)
                 for cpu_block, gpu_block in mapping.items()]
@@ -576,18 +568,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # dict is efficient in lookup `if gpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            self.block_tables[seq.seq_id] = \
-                self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.gpu_allocator,
-                                       self.cpu_allocator,
-                                       mapping)
+            self.block_tables[seq.seq_id] = self._swap_block_table(
+                self.block_tables[seq.seq_id], self.gpu_allocator,
+                self.cpu_allocator, mapping)

         if seq_group.is_encoder_decoder():
-            self.cross_block_tables[request_id] = \
-                self._swap_block_table(self.cross_block_tables[request_id],
-                                       self.gpu_allocator,
-                                       self.cpu_allocator,
-                                       mapping)
+            self.cross_block_tables[request_id] = self._swap_block_table(
+                self.cross_block_tables[request_id], self.gpu_allocator,
+                self.cpu_allocator, mapping)

         return [(cpu_block.block_number, gpu_block.block_number)
                 for cpu_block, gpu_block in mapping.items()]
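
Note on the hunk at old line 406 above: because hash_of_block now returns None whenever the block is not yet full, the caller can drop the separate _is_last_block_full() check and forward the (possibly None) hash directly to the allocator. A minimal sketch of that control flow, with simplified, illustrative signatures rather than vLLM's exact API:

from typing import Optional


def allocate_last_physical_block(seq, gpu_allocator, enable_caching: bool):
    # Sketch only: mirrors the simplified allocation path after this change.
    if not enable_caching:
        # Prefix caching disabled: no hashing, just grab a fresh block.
        return gpu_allocator.allocate()

    last_idx = len(seq.logical_token_blocks) - 1
    # None here means "the last block is not full yet", so it cannot be shared.
    block_hash: Optional[int] = seq.hash_of_block(last_idx)
    num_hashed_tokens = seq.num_hashed_tokens_of_block(last_idx)
    return gpu_allocator.allocate(block_hash, num_hashed_tokens)

In the promotion path (old line 368), the new assert is the complementary case: promotion only happens once the last block has just become full, so hash_of_block must return a real hash there.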

View File

@@ -388,7 +388,7 @@ class Scheduler:
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            A tuple of remaining running queue (should be always 0) after
            scheduling and SchedulerRunningOutputs.
@@ -655,11 +655,12 @@ class Scheduler:
             assert len(waiting_seqs) == 1, (
                 "Waiting sequence group should have only one prompt "
                 "sequence.")
+            waiting_seq = waiting_seqs[0]
             num_new_tokens = self._get_num_new_tokens(seq_group,
                                                       SequenceStatus.WAITING,
                                                       enable_chunking, budget)
             if not enable_chunking:
-                num_prompt_tokens = waiting_seqs[0].get_len()
+                num_prompt_tokens = waiting_seq.get_len()
                 assert num_new_tokens == num_prompt_tokens

             prompt_limit = self._get_prompt_limit(seq_group)
@@ -667,8 +668,7 @@ class Scheduler:
                 logger.warning(
                     "Input prompt (%d tokens) is too long"
                     " and exceeds limit of %d", num_new_tokens, prompt_limit)
-                for seq in waiting_seqs:
-                    seq.status = SequenceStatus.FINISHED_IGNORED
+                waiting_seq.status = SequenceStatus.FINISHED_IGNORED
                 ignored_seq_groups.append(seq_group)
                 waiting_queue.popleft()
                 continue
@@ -731,7 +731,7 @@ class Scheduler:
     def _schedule_default(self) -> SchedulerOutputs:
         """Schedule queued requests.

        The current policy is designed to optimize the throughput. First,
        it batches as many prefill requests as possible. And it schedules
        decodes. If there's a pressure on GPU memory, decode requests can
@@ -825,7 +825,7 @@ class Scheduler:
     def _schedule_chunked_prefill(self):
         """Schedule queued requests.

        Chunked prefill allows to chunk prefill requests, batch them together
        with decode requests. This policy 1. schedule as many decoding requests
        as possible. 2. schedule chunked prefill requests that are not

View File

@@ -270,15 +270,24 @@ class Sequence:
         return self.output_text[:-buffer_length] if truncate else (
             self.output_text)

-    def hash_of_block(self, logical_idx: int) -> int:
-        # TODO This can produce incorrect hash when block size > prompt size
-
-        # Compute the number of tokens in the sequence
+    def hash_of_block(self, logical_idx: int) -> Optional[int]:
+        """Return the hash of the block if it is full."""
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
-        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
-        return hash((hashed_tokens, self.lora_int_id))
+        assert logical_idx < len(self.logical_token_blocks), (
+            f"logical_idx={logical_idx} is out of range for "
+            f"logical_token_blocks={len(self.logical_token_blocks)}")
+        block = self.logical_token_blocks[logical_idx]
+        if block.block_hash is not None:
+            return block.block_hash
+        if not block.is_full():
+            return None
+        num_hashed_tokens = self.num_hashed_tokens_of_block(logical_idx)
+        hashed_tokens = self.data.get_prefix_token_ids(num_hashed_tokens)
+        block_hash = hash((hashed_tokens, self.lora_int_id))
+        # Cache the block hash for future use.
+        block.block_hash = block_hash
+        return block_hash

     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
@@ -614,7 +623,7 @@ class SequenceGroupMetadata:
         state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
                           unless you are working with an encoder/decoder
                           model.
         cross_block_table: Optional cross-attention block table associated
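
Taken together, the Sequence-side change above is the heart of the fix: a block's hash is computed only once the block is full, cached on the block, and None is returned otherwise, so non-full blocks can never be mistaken for sharable ones. A self-contained sketch of that pattern using toy stand-ins (Block and Seq below are illustrative only, not vLLM's LogicalTokenBlock/Sequence):

from typing import List, Optional

BLOCK_SIZE = 4


class Block:
    # Stand-in for LogicalTokenBlock: a fixed-size slice of the token stream.
    def __init__(self) -> None:
        self.token_ids: List[int] = []
        self.block_hash: Optional[int] = None  # set once the block is full

    def is_full(self) -> bool:
        return len(self.token_ids) == BLOCK_SIZE


class Seq:
    # Stand-in for Sequence: owns the logical blocks of one request.
    def __init__(self, token_ids: List[int]) -> None:
        self.token_ids = token_ids
        self.blocks: List[Block] = []
        for start in range(0, len(token_ids), BLOCK_SIZE):
            block = Block()
            block.token_ids = token_ids[start:start + BLOCK_SIZE]
            self.blocks.append(block)

    def hash_of_block(self, idx: int) -> Optional[int]:
        block = self.blocks[idx]
        if block.block_hash is not None:   # reuse the cached hash
            return block.block_hash
        if not block.is_full():            # partial blocks are never shared
            return None
        # Hash the whole prefix up to and including this block, so two
        # sequences map to the same hash only when their prefixes match.
        prefix = tuple(self.token_ids[:(idx + 1) * BLOCK_SIZE])
        block.block_hash = hash(prefix)
        return block.block_hash


seq = Seq([1, 2, 3, 4, 5, 6])            # block 0 is full, block 1 is partial
assert seq.hash_of_block(0) is not None  # full block: hashable and cached
assert seq.hash_of_block(1) is None      # non-full block: no hash yet

Caching the hash on the block also addresses the O(L^2) concern noted in the TODO for repeated lookups: each full block's prefix is hashed at most once per sequence.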