Fix auto prefix bug (#3239)

commit b35cc93420
parent 8cbba4622c
Author: ElizaWszola
Date:   2024-03-08 01:37:28 +01:00
Committed by: GitHub

3 changed files with 51 additions and 12 deletions


@@ -0,0 +1,34 @@
+import pytest
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.sampling_params import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("block_size", [16])
+def test_computed_prefix_blocks(model: str, block_size: int):
+    # This test checks if we are able to run the engine to completion
+    # without triggering asserts.
+    # We are in a scenario where all blocks from the second request's prompt
+    # are full and already computed when the second request arrives.
+    prompt = (
+        "You are a helpful assistant. How do I build a car from cardboard and "
+        "paper clips? Is there an easy to follow video tutorial available "
+        "online for free?")
+    prompt2 = (
+        " Please recommend to me some resources where I can learn not only to "
+        "handle technical difficulties of building a car, but also "
+        "decoration.")
+
+    engine_args = EngineArgs(model=model,
+                             block_size=block_size,
+                             enable_prefix_caching=True)
+
+    engine = LLMEngine.from_engine_args(engine_args)
+    sampling_params = SamplingParams()
+
+    engine.add_request("0", prompt + prompt2, sampling_params)
+    engine.step()
+    engine.add_request("1", prompt, sampling_params)
+    engine.step()
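
The test relies on request "0" (prompt + prompt2) populating the prefix cache, so that when request "1" (prompt alone) arrives, every full block of its prompt is already computed. A minimal sketch of the block arithmetic involved, assuming the test's block_size of 16 (the helper below is illustrative, not a vLLM API):

def num_full_blocks(num_tokens: int, block_size: int = 16) -> int:
    # Only completely filled blocks can be served from the prefix
    # cache; a trailing, partially filled block must still be computed.
    return num_tokens // block_size

assert num_full_blocks(48) == 3  # 48 tokens fill exactly 3 blocks
assert num_full_blocks(50) == 3  # the 2 leftover tokens are not cached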


@@ -1,6 +1,6 @@
 """A block manager that manages token blocks."""
 import enum
-from itertools import count
+from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple
 
@@ -426,23 +426,29 @@ class BlockSpaceManager:
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_last_full_block_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
         max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
-        block_table[max_full_block].computed = True
+        for i in reversed(range(max_full_block)):
+            if block_table[i].computed:
+                break
+            block_table[i].computed = True
 
-    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
+    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
         if seq.seq_id not in self.block_tables:
             return []
         block_table = self.block_tables[seq.seq_id]
-        for block_idx in reversed(range(len(block_table))):
-            if block_table[block_idx].computed:
-                return [b.block_number for b in block_table[:block_idx + 1]]
-        return []
+        # NOTE We exclude the last block to avoid the case where the entire
+        # prompt is cached. This would cause erroneous behavior in model
+        # runner.
+        return [
+            b.block_number
+            for b in takewhile(lambda b: b.computed, block_table[:-1])
+        ]
 
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
@@ -451,14 +457,12 @@ class BlockSpaceManager:
             return []
 
         ids_list = [
-            self.get_all_block_ids_till_computed(seq)
+            self.get_all_computed_blocks(seq)
             for seq in iter(seq_group.seqs_dict.values())
        ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # NOTE: We only mark the last full block because with prefix caching,
-        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
-                self.compute_last_full_block_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq)
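
To see what the renamed helpers compute, here is a self-contained sketch; the Block class is a hypothetical stand-in for vLLM's physical token blocks. get_all_computed_blocks now returns only the leading run of computed blocks and always drops the final block, so a fully cached prompt still leaves at least one block for the model runner; get_common_computed_block_ids then intersects the per-sequence runs with os.path.commonprefix, which works on any sequences, including lists of block ids:

from itertools import takewhile
from os.path import commonprefix

class Block:
    # Hypothetical stand-in for vLLM's physical token block.
    def __init__(self, block_number: int, computed: bool):
        self.block_number = block_number
        self.computed = computed

def computed_prefix(block_table):
    # Leading run of computed blocks; the last block is excluded so a
    # fully cached prompt still leaves work for the model runner.
    return [
        b.block_number
        for b in takewhile(lambda b: b.computed, block_table[:-1])
    ]

table = [Block(7, True), Block(3, True), Block(9, False), Block(4, True)]
print(computed_prefix(table))          # [7, 3]: stops at the first uncomputed block
print(commonprefix([[7, 3], [7, 5]]))  # [7]: the shared computed prefix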


@@ -215,6 +215,7 @@ class ModelRunner:
                 slot_mapping[-1].append(slot)
 
         max_prompt_len = max(subquery_lens)
+        assert max_prompt_len > 0
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
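
The added assert guards exactly the case the block manager's NOTE mentions: if the entire prompt were served from the cache, every subquery length would be 0 and the padded prompt tensor would be empty. A simplified sketch of that failure mode; make_tensor_with_pad below is a stripped-down, assumed stand-in for vLLM's private _make_tensor_with_pad helper:

import torch

def make_tensor_with_pad(rows, max_len, pad):
    # Right-pad each row to max_len and stack into a (batch, max_len)
    # tensor, mirroring the padding behavior relied on above.
    return torch.tensor([row + [pad] * (max_len - len(row)) for row in rows])

# Healthy case: at least one uncached token per prompt.
print(make_tensor_with_pad([[11, 12], [21]], max_len=2, pad=0).shape)  # (2, 2)

# The bug: a fully cached prompt makes max(subquery_lens) == 0, so the
# model would receive an empty (batch, 0) input -- the assert rules it out.
print(make_tensor_with_pad([[], []], max_len=0, pad=0).shape)  # (2, 0)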