Compare commits

...

2 Commits

SHA1        Message                                  Date
1936d7bab0  format                                   2024-06-02 00:02:54 +00:00
996cf2de5c  Fix hashing logic for non-full blocks    2024-06-02 00:01:30 +00:00
3 changed files with 41 additions and 43 deletions

View File

@@ -1,5 +1,5 @@
 """Token blocks."""
-from typing import List
+from typing import List, Optional

 from vllm.utils import Device
@@ -25,6 +25,7 @@ class LogicalTokenBlock:
         self.token_ids = [_BLANK_TOKEN_ID] * block_size
         self.num_tokens = 0
+        self.block_hash: Optional[int] = None

     def is_empty(self) -> bool:
         return self.num_tokens == 0
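
Taken together, these two hunks give each logical block a slot for caching its content hash. A minimal sketch of the resulting class, assuming constructor arguments and an is_full() helper that are not shown in this diff:

from typing import List, Optional

_BLANK_TOKEN_ID = -1


class LogicalTokenBlock:
    """A block of token IDs in a sequence (sketch based on the hunks above)."""

    def __init__(self, block_number: int, block_size: int) -> None:
        self.block_number = block_number
        self.block_size = block_size
        self.token_ids: List[int] = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0
        # New field: lazily computed hash of the block's contents. It stays
        # None until the block is full and Sequence.hash_of_block() runs.
        self.block_hash: Optional[int] = None

    def is_empty(self) -> bool:
        return self.num_tokens == 0

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size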

View File

@@ -262,8 +262,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         self.cross_block_tables: Dict[str, BlockTable] = {}

     def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
-        return 0 if seq is None \
-            else len(seq.logical_token_blocks)
+        return 0 if seq is None else len(seq.logical_token_blocks)

     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
@@ -275,8 +274,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             seq_group.get_seqs(status=SequenceStatus.WAITING)[0])
         cross_num_required_blocks = self._get_seq_num_required_blocks(
             seq_group.get_encoder_seq())
-        num_required_blocks = self_num_required_blocks + \
-                              cross_num_required_blocks
+        num_required_blocks = (self_num_required_blocks +
+                               cross_num_required_blocks)

         if self.block_sliding_window is not None:
@@ -293,9 +292,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         else:
             return AllocStatus.LATER

-    def _allocate_sequence(self, \
-                           seq: Sequence, \
-                           ref_count: int, \
+    def _allocate_sequence(self,
+                           seq: Sequence,
+                           ref_count: int,
                            is_encoder_decoder: bool = True) -> BlockTable:
         # Allocate new physical token blocks that will store the prompt tokens.
         num_prompt_blocks = len(seq.logical_token_blocks)
@@ -328,10 +327,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # NOTE: Here we assume that all sequences in the group have the same
         # decoder prompt.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
-        block_table: BlockTable = \
-            self._allocate_sequence(seq,
-                                    seq_group.num_seqs(),
-                                    is_encoder_decoder)
+        block_table: BlockTable = self._allocate_sequence(
+            seq, seq_group.num_seqs(), is_encoder_decoder)

         # Assign the self-attention block tables for each sequence.
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
@@ -368,6 +365,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # Compute a new hash for the block so that it can be shared by other
         # Sequences
         new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        assert new_hash is not None, "Last block is not full."

         # if new_hash is already in the cached table, then free last_block
         # and return the cached version
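
The new assert documents an invariant of this call site: promotion to a shared block only happens once the last block is full, and with the new logic hash_of_block() returns None for anything less. A rough standalone sketch of that promotion flow, with the hash-to-block cache written as a plain dict purely for illustration (not vLLM's actual allocator API):

from typing import Callable, Dict


def promote_last_block(seq, last_block,
                       cached_blocks: Dict[int, object],
                       free_block: Callable[[object], None]) -> object:
    # Compute a new hash for the block so it can be shared by other sequences.
    new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
    # hash_of_block() only returns None for a non-full block, which cannot
    # happen on the promotion path.
    assert new_hash is not None, "Last block is not full."
    if new_hash in cached_blocks:
        # An identical block already exists: free ours and reuse the cached one.
        free_block(last_block)
        return cached_blocks[new_hash]
    # Otherwise register this block under its new hash and keep it.
    cached_blocks[new_hash] = last_block
    return last_block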
@@ -406,9 +404,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # content hash.
         if not self.enable_caching:
             return self.gpu_allocator.allocate()
-        block_hash: Optional[int] = None
-        if (self._is_last_block_full(seq)):
-            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
+        block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
         num_hashed_tokens = seq.num_hashed_tokens_of_block(
             len(seq.logical_token_blocks) - 1)
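
This simplification works because hash_of_block() now reports fullness itself: the explicit _is_last_block_full() check is gone and the Optional hash is passed straight through. A sketch of the resulting allocation step; the allocate(block_hash, num_hashed_tokens) signature is assumed here rather than taken from this diff:

from typing import Optional


def allocate_last_physical_block(seq, gpu_allocator, enable_caching: bool):
    # Without prefix caching there is nothing to hash.
    if not enable_caching:
        return gpu_allocator.allocate()
    last_idx = len(seq.logical_token_blocks) - 1
    # None while the last block is still filling; a cached hash once it is full.
    block_hash: Optional[int] = seq.hash_of_block(last_idx)
    num_hashed_tokens = seq.num_hashed_tokens_of_block(last_idx)
    return gpu_allocator.allocate(block_hash, num_hashed_tokens)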
@@ -553,18 +549,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # dict is efficient in lookup `if cpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
-            self.block_tables[seq.seq_id] = \
-                self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.cpu_allocator,
-                                       self.gpu_allocator,
-                                       mapping)
+            self.block_tables[seq.seq_id] = self._swap_block_table(
+                self.block_tables[seq.seq_id], self.cpu_allocator,
+                self.gpu_allocator, mapping)

         if seq_group.is_encoder_decoder():
-            self.cross_block_tables[request_id] = \
-                self._swap_block_table(self.cross_block_tables[request_id],
-                                       self.cpu_allocator,
-                                       self.gpu_allocator,
-                                       mapping)
+            self.cross_block_tables[request_id] = self._swap_block_table(
+                self.cross_block_tables[request_id], self.cpu_allocator,
+                self.gpu_allocator, mapping)

         return [(cpu_block.block_number, gpu_block.block_number)
                 for cpu_block, gpu_block in mapping.items()]
@@ -580,18 +572,14 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # dict is efficient in lookup `if gpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-            self.block_tables[seq.seq_id] = \
-                self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.gpu_allocator,
-                                       self.cpu_allocator,
-                                       mapping)
+            self.block_tables[seq.seq_id] = self._swap_block_table(
+                self.block_tables[seq.seq_id], self.gpu_allocator,
+                self.cpu_allocator, mapping)

         if seq_group.is_encoder_decoder():
-            self.cross_block_tables[request_id] = \
-                self._swap_block_table(self.cross_block_tables[request_id],
-                                       self.gpu_allocator,
-                                       self.cpu_allocator,
-                                       mapping)
+            self.cross_block_tables[request_id] = self._swap_block_table(
+                self.cross_block_tables[request_id], self.gpu_allocator,
+                self.cpu_allocator, mapping)

         return [(cpu_block.block_number, gpu_block.block_number)
                 for cpu_block, gpu_block in mapping.items()]

View File

@@ -269,15 +269,24 @@ class Sequence:
         return self.output_text[:-buffer_length] if truncate else (
             self.output_text)

-    def hash_of_block(self, logical_idx: int) -> int:
-        # TODO This can produce incorrect hash when block size > prompt size
-        # Compute the number of tokens in the sequence
+    def hash_of_block(self, logical_idx: int) -> Optional[int]:
+        """Return the hash of the block if it is full."""
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
-        num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
-        return hash((hashed_tokens, self.lora_int_id))
+        assert logical_idx < len(self.logical_token_blocks), (
+            f"logical_idx={logical_idx} is out of range for "
+            f"logical_token_blocks={len(self.logical_token_blocks)}")
+        block = self.logical_token_blocks[logical_idx]
+        if block.block_hash is not None:
+            return block.block_hash
+        if not block.is_full():
+            return None
+        num_hashed_tokens = self.num_hashed_tokens_of_block(logical_idx)
+        hashed_tokens = self.data.get_prefix_token_ids(num_hashed_tokens)
+        block_hash = hash((hashed_tokens, self.lora_int_id))
+        # Cache the block hash for future use.
+        block.block_hash = block_hash
+        return block_hash

     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
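
Two details of the rewritten hash_of_block() are worth spelling out: num_hashed_tokens_of_block() counts the whole prefix up to and including the block, so the hash identifies a prefix rather than an isolated block, and the hash is computed at most once per block and then cached on it. A small self-contained illustration of both points, using toy classes rather than vLLM's real Sequence:

from typing import List, Optional, Tuple

BLOCK_SIZE = 4


class ToyBlock:
    def __init__(self) -> None:
        self.token_ids: List[int] = []
        self.block_hash: Optional[int] = None

    def is_full(self) -> bool:
        return len(self.token_ids) == BLOCK_SIZE


class ToySequence:
    def __init__(self, tokens: List[int], lora_int_id: int = 0) -> None:
        self.tokens = tokens
        self.lora_int_id = lora_int_id
        self.blocks: List[ToyBlock] = []
        for i, tok in enumerate(tokens):
            if i % BLOCK_SIZE == 0:
                self.blocks.append(ToyBlock())
            self.blocks[-1].token_ids.append(tok)

    def num_hashed_tokens_of_block(self, logical_idx: int) -> int:
        # The hash of block i covers the entire prefix ending at block i.
        return logical_idx * BLOCK_SIZE + BLOCK_SIZE

    def hash_of_block(self, logical_idx: int) -> Optional[int]:
        block = self.blocks[logical_idx]
        if block.block_hash is not None:
            return block.block_hash    # already cached
        if not block.is_full():
            return None                # non-full blocks have no hash
        prefix: Tuple[int, ...] = tuple(
            self.tokens[:self.num_hashed_tokens_of_block(logical_idx)])
        block.block_hash = hash((prefix, self.lora_int_id))
        return block.block_hash


seq = ToySequence(list(range(6)))        # six tokens: one full block + two tokens
print(seq.hash_of_block(1))              # None: the last block is not full
print(seq.hash_of_block(0) is not None)  # True: a full block gets a hash
print(seq.hash_of_block(0) == seq.hash_of_block(0))  # True: cached and stable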
@@ -632,7 +641,7 @@ class SequenceGroupMetadata:
         state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
             (SequenceGroup.encoder_seq). Should be None
             unless you are working with an encoder/decoder
             model.
         cross_block_table: Optional cross-attention block table associated