[Optimization] use a pool to reuse LogicalTokenBlock.token_ids (#5584)

This commit is contained in:
youkaichao
2024-06-17 15:08:05 -07:00
committed by GitHub
parent 1b44aaf4e3
commit e441bad674

View File

@ -1,5 +1,7 @@
"""Token blocks."""
from typing import List
import weakref
from collections import defaultdict
from typing import Dict, List
from vllm.utils import Device
@ -7,6 +9,35 @@ _BLANK_TOKEN_ID = -1
# NOTE(review): presumably a sentinel meaning "never accessed yet"; the code
# that reads it is outside this view -- confirm against its users.
DEFAULT_LAST_ACCESSED_TIME = -1
# Alias for the token-id storage of a single block: a plain list of ints.
TokensBlock = List[int]
class BlockPool:
    """Recycles the token-id lists used by logical token blocks.

    Logical blocks are created in bulk when requests arrive and discarded in
    bulk when they finish. Building a fresh `token_ids` list for each of them
    is comparatively costly, so instead of dropping a list when its block
    goes away we park it here and hand it to the next block that asks for
    one of the same size.

    Lists are bucketed by their length. A recycled list is returned exactly
    as it was given back: its previous contents are NOT reset to
    `_BLANK_TOKEN_ID`.
    """

    def __init__(self) -> None:
        # Maps a block size to the stack of idle lists of that length.
        self.pool: Dict[int, List[TokensBlock]] = defaultdict(list)

    def alloc_block(self, block_size: int) -> TokensBlock:
        """Return a list of length `block_size`, reusing an idle one if any.

        When no idle list of that size is available, a new list filled with
        `_BLANK_TOKEN_ID` is created.
        """
        # `.get` does not trigger the defaultdict factory, so a miss leaves
        # the mapping untouched.
        idle = self.pool.get(block_size)
        if not idle:
            return [_BLANK_TOKEN_ID] * block_size
        return idle.pop()

    def del_block(self, block: TokensBlock) -> None:
        """Give `block` back to the pool, filed under its own length."""
        bucket = self.pool[len(block)]
        bucket.append(block)
# Module-wide pool instance. LogicalTokenBlock takes its `token_ids` list from
# here and registers a finalizer that hands the list back via `del_block`.
_BLOCK_POOL = BlockPool()
class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right.
@ -23,7 +54,13 @@ class LogicalTokenBlock:
self.block_number = block_number
self.block_size = block_size
self.token_ids = [_BLANK_TOKEN_ID] * block_size
self.token_ids = _BLOCK_POOL.alloc_block(block_size)
# this finalizer is used to return the block to the pool when the object is deleted # noqa
# NOTE: don't use __del__ because it cannot guarantee the order of finalization, # noqa
# i.e. `self.token_ids` may be deleted before `self`, and we lose
# the opportunity to return the block to the pool
self._finalizer = weakref.finalize(self, _BLOCK_POOL.del_block,
self.token_ids)
self.num_tokens = 0
def is_empty(self) -> bool: