Mirror of https://github.com/vllm-project/vllm.git
Fix auto prefix bug (#3239)

With automatic prefix caching enabled, mark every full block of a sequence as computed rather than only the last one, and never report a prompt's final block as cached, so the model runner always receives at least one token to compute.
tests/engine/test_computed_prefix_blocks.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+import pytest
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("block_size", [16])
+def test_computed_prefix_blocks(model: str, block_size: int):
+    # This test checks if we are able to run the engine to completion
+    # without triggering asserts.
+    # We are in a scenario where all blocks from the second request's prompt
+    # are full and already computed when the second request arrives.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+    prompt2 = (
+        " Please recommend to me some resources where I can learn not only to "
+        "handle technical difficulties of building a car, but also "
+        "decoration.")
+
+    engine_args = EngineArgs(model=model,
+                             block_size=block_size,
+                             enable_prefix_caching=True)
+
+    engine = LLMEngine.from_engine_args(engine_args)
+    sampling_params = SamplingParams()
+
+    engine.add_request("0", prompt + prompt2, sampling_params)
+    engine.step()
+    engine.add_request("1", prompt, sampling_params)
+    engine.step()
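The scenario the test constructs is easiest to see with concrete numbers. Below is an illustrative sketch of the failure mode with a fully cached prompt; the block size matches the test, but the token counts are assumptions for illustration, not values from the commit:

# Illustrative numbers only: suppose a prompt occupies exactly 3 full blocks.
block_size = 16
prompt_len = 48          # assumed prompt length in tokens
cached_blocks = 3        # every prompt block already computed from request "0"

# Tokens the model runner would still have to process for this prompt:
subquery_len = prompt_len - cached_blocks * block_size
print(subquery_len)      # 0 -> max(subquery_lens) == 0, the degenerate case

Excluding the last block from the cached prefix (see the block manager change below) keeps subquery_len at least one token long.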
vllm/core/block_manager.py
@@ -1,6 +1,6 @@
 """A block manager that manages token blocks."""
 import enum
-from itertools import count
+from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple
 
@@ -426,23 +426,29 @@ class BlockSpaceManager:
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_last_full_block_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
         max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
-        block_table[max_full_block].computed = True
+        for i in reversed(range(max_full_block)):
+            if block_table[i].computed:
+                break
+            block_table[i].computed = True
 
-    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
+    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
         if seq.seq_id not in self.block_tables:
             return []
         block_table = self.block_tables[seq.seq_id]
-        for block_idx in reversed(range(len(block_table))):
-            if block_table[block_idx].computed:
-                return [b.block_number for b in block_table[:block_idx + 1]]
-        return []
+        # NOTE We exclude the last block to avoid the case where the entire
+        # prompt is cached. This would cause erroneous behavior in model
+        # runner.
+        return [
+            b.block_number
+            for b in takewhile(lambda b: b.computed, block_table[:-1])
+        ]
 
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
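To see what the rewritten helpers do, here is a minimal, self-contained sketch; FakeBlock and the block counts are stand-ins invented for illustration, not vLLM classes:

from dataclasses import dataclass
from itertools import takewhile

@dataclass
class FakeBlock:
    block_number: int
    computed: bool = False

# A sequence of 64 tokens with block_size 16 fills four blocks,
# so max_full_block = 64 // 16 - 1 = 3.
table = [FakeBlock(0), FakeBlock(1), FakeBlock(2), FakeBlock(3)]
max_full_block = 3

# compute_full_blocks_in_seq: mark blocks below the last full one as
# computed, stopping early once an already-computed block is found.
for i in reversed(range(max_full_block)):
    if table[i].computed:
        break
    table[i].computed = True

# get_all_computed_blocks: take the leading run of computed blocks, but
# never the final block, so a prompt is never served 100% from cache.
prefix = [b.block_number for b in takewhile(lambda b: b.computed, table[:-1])]
print(prefix)  # [0, 1, 2]

Marking every full block, rather than only the last one, is what lets the takewhile scan recover the whole computed prefix without trusting blocks that were never actually marked.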
@@ -451,14 +457,12 @@ class BlockSpaceManager:
             return []
 
         ids_list = [
-            self.get_all_block_ids_till_computed(seq)
+            self.get_all_computed_blocks(seq)
             for seq in iter(seq_group.seqs_dict.values())
         ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # NOTE: We only mark the last full block because with prefix caching,
-        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
-                self.compute_last_full_block_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq)
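get_common_computed_block_ids reuses os.path.commonprefix, which compares its arguments element-wise and therefore works on lists of block ids just as it does on path strings. A quick sketch with made-up ids:

from os.path import commonprefix

# Hypothetical per-sequence computed-block ids within one sequence group:
ids_list = [[0, 1, 2, 5], [0, 1, 2, 7], [0, 1, 3]]
print(commonprefix(ids_list))  # [0, 1] -- the longest shared prefix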
vllm/worker/model_runner.py
@@ -215,6 +215,7 @@ class ModelRunner:
                 slot_mapping[-1].append(slot)
 
         max_prompt_len = max(subquery_lens)
+        assert max_prompt_len > 0
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,