mirror of https://github.com/vllm-project/vllm.git
[Core] Ignore infeasible swap requests. (#4557)
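The change in one sentence: BlockSpaceManager.can_swap_in now reports an AllocStatus (OK, LATER, or NEVER) instead of a bare bool, and the scheduler fails requests whose swap-in could never fit in the GPU KV cache instead of retrying them forever. Below is a minimal standalone sketch of that decision rule, not the vLLM code itself; the parameter names are stand-ins for values the real block manager derives from a sequence group.

    from enum import Enum


    class AllocStatus(Enum):
        OK = 1      # the swap-in fits right now
        LATER = 2   # does not fit now, but will once blocks free up
        NEVER = 3   # cannot ever fit, even with every GPU block free


    def can_swap_in(total_blocks: int, free_blocks: int, required_blocks: int,
                    watermark_blocks: int) -> AllocStatus:
        # Same three-way decision the commit introduces below.
        if total_blocks < required_blocks:
            return AllocStatus.NEVER
        if free_blocks - required_blocks >= watermark_blocks:
            return AllocStatus.OK
        return AllocStatus.LATER


    # A group that needs more blocks than the GPU has at all is infeasible,
    # so the scheduler ignores it instead of waiting for space that never comes.
    assert can_swap_in(total_blocks=8, free_blocks=8, required_blocks=10,
                       watermark_blocks=1) == AllocStatus.NEVER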
@@ -7,6 +7,7 @@ pytest tests/basic_correctness/test_preemption.py`.
 """
 import pytest
 
+from vllm import SamplingParams
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                  ENABLE_ARTIFICIAL_PREEMPT)
 
@@ -136,3 +137,87 @@ def test_swap(
             assert hf_output_ids[j] == vllm_output_ids[j], (
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+@pytest.mark.parametrize("beam_width", [4])
+def test_swap_infeasible(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Verify an infeasible swap request will be ignored."""
+    BLOCK_SIZE = 16
+    prefill_blocks = 2
+    decode_blocks = max_tokens // BLOCK_SIZE
+    example_prompts = example_prompts[:1]
+
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        swap_space=10,
+        block_size=BLOCK_SIZE,
+        # Since beam search has more than 1 sequence, prefill + decode blocks
+        # are not enough to finish.
+        num_gpu_blocks_override=prefill_blocks + decode_blocks,
+        max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
+    )
+    sampling_params = SamplingParams(n=beam_width,
+                                     use_beam_search=True,
+                                     temperature=0.0,
+                                     max_tokens=max_tokens,
+                                     ignore_eos=True)
+    req_outputs = vllm_model.model.generate(
+        example_prompts,
+        sampling_params=sampling_params,
+    )
+    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
+            ARTIFICIAL_PREEMPTION_MAX_CNT)
+    del vllm_model
+    # Verify the request is ignored and does not hang.
+    assert req_outputs[0].outputs[0].finish_reason == "length"
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+def test_preemption_infeasible(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    """Verify an infeasible preemption request will be ignored."""
+    BLOCK_SIZE = 16
+    prefill_blocks = 2
+    decode_blocks = max_tokens // BLOCK_SIZE
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        block_size=BLOCK_SIZE,
+        # Not enough gpu blocks to complete a single sequence, so
+        # preemption should happen, and the sequence should be
+        # ignored instead of hanging forever.
+        num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
+        max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
+    )
+    sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
+    req_outputs = vllm_model.model.generate(
+        example_prompts,
+        sampling_params=sampling_params,
+    )
+
+    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
+            ARTIFICIAL_PREEMPTION_MAX_CNT)
+    del vllm_model
+    # Verify the request is ignored and does not hang.
+    for req_output in req_outputs:
+        outputs = req_output.outputs
+        assert len(outputs) == 1
+        assert outputs[0].finish_reason == "length"
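A rough sanity check of why the swap in test_swap_infeasible can never succeed; the arithmetic below is illustrative and not part of the test (the group's exact block demand also depends on prompt length and prefix sharing):

    BLOCK_SIZE = 16
    max_tokens = 96
    prefill_blocks = 2
    decode_blocks = max_tokens // BLOCK_SIZE          # 6
    num_gpu_blocks = prefill_blocks + decode_blocks   # 8 blocks
    kv_slots = num_gpu_blocks * BLOCK_SIZE            # 128 token positions in total
    beam_width = 4
    # A single sequence fits in 128 slots, but beam search keeps several
    # candidate sequences alive, so the swapped-out group needs more physical
    # blocks than the GPU has at all; can_swap_in reports NEVER and the
    # request finishes as ignored instead of hanging.
    assert kv_slots == 128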
@@ -224,7 +224,7 @@ def test_swap():
 
     # Swap seq group from CPU -> GPU.
     cpu_blocks = block_manager.get_block_table(prompt)
-    assert block_manager.can_swap_in(seq_group)
+    assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
    mapping = block_manager.swap_in(seq_group)
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
 import pytest  # noqa
 
 from vllm.config import CacheConfig, SchedulerConfig
+from vllm.core.interfaces import AllocStatus
 from vllm.core.scheduler import Scheduler
 from vllm.sequence import Logprob, SequenceGroup
 
@@ -410,7 +411,7 @@ def test_running_prefill_prioritized_over_swap():
 
     # Add 1 more task. Swap is not possible, so prefill is running.
     scheduler.block_manager.can_swap_in = MagicMock()
-    scheduler.block_manager.can_swap_in.return_value = False
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
 
     _, seq_group2 = create_dummy_prompt("2", prompt_length=60)
     scheduler.add_seq_group(seq_group2)
@@ -423,7 +424,7 @@ def test_running_prefill_prioritized_over_swap():
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
 
     # Now although swap is possible, running prefill is prioritized.
-    scheduler.block_manager.can_swap_in.return_value = True
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
@@ -791,7 +791,7 @@ def test_schedule_swapped_cannot_swap_in():
 
     # The last request should be swapped out.
     scheduler.block_manager.can_swap_in = MagicMock()
-    scheduler.block_manager.can_swap_in.return_value = False
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
     # Since we cannot swap in, none of the requests are swapped in.
     budget = create_token_budget()
     remaining_swapped, output = scheduler._schedule_swapped(
@@ -803,6 +803,34 @@
     assert len(output.prefill_seq_groups) == 0
 
 
+def test_infeasible_swap():
+    scheduler = initialize_scheduler()
+    swapped = deque()
+    policy = PolicyFactory.get_policy(policy_name="fcfs")
+    curr_loras = None
+    blocks_to_swap_out = {}
+    for _ in range(2):
+        _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+        scheduler._allocate_and_set_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        scheduler._swap_out(seq_group, blocks_to_swap_out)
+        swapped.append(seq_group)
+
+    # The last request should be swapped out.
+    scheduler.block_manager.can_swap_in = MagicMock()
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
+    # Since we cannot swap in, none of the requests are swapped in.
+    budget = create_token_budget()
+    remaining_swapped, output = scheduler._schedule_swapped(
+        swapped, budget, curr_loras, policy)
+    assert len(remaining_swapped) == 0
+    assert len(output.infeasible_seq_groups) == 2
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 0
+    assert len(output.decode_seq_groups) == 0
+    assert len(output.prefill_seq_groups) == 0
+
+
 def test_schedule_swapped_blocks_to_copy():
     scheduler = initialize_scheduler()
     swapped = deque()
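The scheduler tests above differ mainly in the verdict they force out of the mocked block manager. The snippet below recaps that forcing pattern and, in comments, what _schedule_swapped does with each verdict according to the assertions above; it assumes vllm is importable and uses a MagicMock as a stand-in scheduler.

    from unittest.mock import MagicMock

    from vllm.core.interfaces import AllocStatus

    # The forcing pattern used by the tests above (a MagicMock stands in for
    # the real scheduler here, purely for illustration).
    scheduler = MagicMock()
    scheduler.block_manager.can_swap_in = MagicMock(return_value=AllocStatus.NEVER)
    assert scheduler.block_manager.can_swap_in("any seq group") == AllocStatus.NEVER

    # What _schedule_swapped does with each verdict, per the assertions above:
    #   AllocStatus.LATER -> break out of the loop; the groups simply stay
    #                        swapped out and are retried on a later step.
    #   AllocStatus.NEVER -> the group is popped from the queue, reported in
    #                        output.infeasible_seq_groups, and the token budget
    #                        is left untouched (num_batched_tokens stays 0).
    #   AllocStatus.OK    -> the group is swapped back in and scheduled as decode.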
@@ -110,9 +110,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         for block_id in allocator.all_block_ids:
             self._block_ids_to_allocator[block_id] = allocator
 
-    def allocate_mutable(self,
-                         prev_block: Optional[Block],
-                         device: Optional[Device] = None) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
         """Allocates a new mutable block on the specified device.
 
         Args:
@@ -123,13 +122,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         Returns:
             Block: The newly allocated mutable block.
         """
-        assert device is not None
         return self._allocators[device].allocate_mutable(prev_block)
 
-    def allocate_immutable(self,
-                           prev_block: Optional[Block],
-                           token_ids: List[int],
-                           device: Optional[Device] = None) -> Block:
+    def allocate_immutable(self, prev_block: Optional[Block],
+                           token_ids: List[int], device: Device) -> Block:
         """Allocates a new immutable block with the provided token IDs on the
         specified device.
 
@@ -144,7 +140,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
             Block: The newly allocated immutable block containing the provided
                 token IDs.
         """
-        assert device is not None
         return self._allocators[device].allocate_immutable(
             prev_block, token_ids)
 
@@ -175,7 +170,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         allocator = self._block_ids_to_allocator[block_id]
         return allocator.fork(last_block)
 
-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
+    def get_num_free_blocks(self, device: Device) -> int:
         """Returns the number of free blocks available on the specified device.
 
         Args:
@@ -185,9 +180,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         Returns:
             int: The number of free blocks available on the specified device.
         """
-        assert device is not None
         return self._allocators[device].get_num_free_blocks()
 
+    def get_num_total_blocks(self, device: Device) -> int:
+        return self._allocators[device].get_num_total_blocks()
+
     def clear_copy_on_writes(self) -> Dict[int, List[int]]:
         """Clears the copy-on-write (CoW) state and returns the mapping of
         source to destination block IDs.
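Since device is no longer Optional on these methods, call sites have to name the device explicitly instead of relying on a None default. A small sketch of the calling pattern under the new signatures; the helper names are mine, and the import paths are assumed from the modules touched in this commit.

    from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
    from vllm.utils import Device


    def gpu_headroom(allocator: CpuGpuBlockAllocator) -> int:
        """How many GPU blocks could still be handed out right now."""
        # The device is now a required argument; there is no implicit default.
        return allocator.get_num_free_blocks(Device.GPU)


    def gpu_capacity(allocator: CpuGpuBlockAllocator) -> int:
        """The fixed size of the GPU pool, which backs the NEVER check."""
        return allocator.get_num_total_blocks(Device.GPU)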
@@ -108,6 +108,10 @@ class BlockAllocator(ABC):
     def fork(self, last_block: Block) -> List[Block]:
         pass
 
+    @abstractmethod
+    def get_num_total_blocks(self) -> int:
+        pass
+
     @abstractmethod
     def get_num_free_blocks(self) -> int:
         pass
@@ -152,20 +156,21 @@
 class DeviceAwareBlockAllocator(ABC):
 
     @abstractmethod
-    def allocate_mutable(self,
-                         prev_block: Optional[Block],
-                         device: Optional[Device] = None) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
         pass
 
     @abstractmethod
-    def allocate_immutable(self,
-                           prev_block: Optional[Block],
-                           token_ids: List[int],
-                           device: Optional[Device] = None) -> Block:
+    def allocate_immutable(self, prev_block: Optional[Block],
+                           token_ids: List[int], device: Device) -> Block:
         pass
 
     @abstractmethod
-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
+    def get_num_free_blocks(self, device: Device) -> int:
         pass
 
+    @abstractmethod
+    def get_num_total_blocks(self, device: Device) -> int:
+        pass
+
     @abstractmethod
@@ -133,10 +133,12 @@ class NaiveBlockAllocator(BlockAllocator):
 
         return forked_blocks
 
-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
-        assert device is None
+    def get_num_free_blocks(self) -> int:
         return len(self._free_block_indices)
 
+    def get_num_total_blocks(self) -> int:
+        return len(self._all_block_indices)
+
     def _allocate_new_block_id(self) -> BlockId:
         if not self._free_block_indices:
             raise BlockAllocator.NoFreeBlocksError()
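get_num_total_blocks is what lets the block manager tell "not right now" apart from "never": free counts fluctuate as requests run, while the total is fixed at construction. A toy allocator (not the vLLM class) illustrating that invariant:

    class ToyAllocator:
        """Toy stand-in for the free/total bookkeeping in NaiveBlockAllocator."""

        def __init__(self, num_blocks: int) -> None:
            self._all_block_indices = frozenset(range(num_blocks))
            self._free_block_indices = set(range(num_blocks))

        def allocate_block_id(self) -> int:
            return self._free_block_indices.pop()

        def get_num_free_blocks(self) -> int:
            return len(self._free_block_indices)

        def get_num_total_blocks(self) -> int:
            return len(self._all_block_indices)


    alloc = ToyAllocator(num_blocks=4)
    alloc.allocate_block_id()
    # The free count drops as blocks are handed out; the total never changes,
    # which is what makes a "this can never fit" verdict possible.
    assert alloc.get_num_free_blocks() == 3
    assert alloc.get_num_total_blocks() == 4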
@@ -285,6 +285,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         return self._hashless_allocator.get_num_free_blocks(
         ) + self.evictor.num_blocks
 
+    def get_num_total_blocks(self) -> int:
+        return self._hashless_allocator.get_num_total_blocks()
+
     @property
     def all_block_ids(self) -> FrozenSet[int]:
         return self._hashless_allocator.all_block_ids
@@ -47,6 +47,10 @@ class BlockAllocatorBase(ABC):
     def get_num_free_blocks(self) -> int:
         pass
 
+    @abstractmethod
+    def get_num_total_blocks(self) -> int:
+        pass
+
     @abstractmethod
     def contains_block(self, block_hash: int) -> bool:
         pass
@@ -131,6 +135,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
         return (self.num_blocks - self.current_num_blocks +
                 self.evictor.num_blocks)
 
+    def get_num_total_blocks(self) -> int:
+        return self.num_blocks
+
     def contains_block(self, block_hash: int) -> bool:
         return block_hash in self.cached_blocks or block_hash in self.evictor
 
@@ -190,6 +197,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
     def get_num_free_blocks(self) -> int:
         return len(self.free_blocks)
 
+    def get_num_total_blocks(self) -> int:
+        return self.num_blocks
+
     def contains_block(self, block_hash: int) -> bool:
         raise NotImplementedError(
             "Invalid codepath for uncached block allocator.")
@@ -444,7 +454,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
 
     def can_swap_in(self,
                     seq_group: SequenceGroup,
-                    num_lookahead_slots: int = 0) -> bool:
+                    num_lookahead_slots: int = 0) -> AllocStatus:
         assert (num_lookahead_slots == 0
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"
         blocks = self._get_physical_blocks(seq_group)
@@ -454,7 +464,12 @@
         # at least one free block right after the swap-in.
         # NOTE: This should match the logic in can_append_slot().
         num_required_blocks = len(blocks) + num_swapped_seqs
-        return num_free_blocks - num_required_blocks >= self.watermark_blocks
+        if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
+            return AllocStatus.NEVER
+        elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
+            return AllocStatus.OK
+        else:
+            return AllocStatus.LATER
 
     def swap_in(self,
                 seq_group: SequenceGroup,
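To make the three branches above concrete, a small worked example with made-up numbers: an 8-block GPU pool, a watermark of 1 block, and num_required standing for the group's blocks plus one extra per swapped sequence, as in the code above.

    num_total, num_free, watermark = 8, 5, 1

    # Needs more blocks than the GPU has at all -> AllocStatus.NEVER.
    num_required = 10
    assert num_total < num_required

    # Fits in the pool, but not within the free blocks right now -> AllocStatus.LATER.
    num_required = 6
    assert num_total >= num_required
    assert num_free - num_required < watermark

    # Fits now with the watermark to spare -> AllocStatus.OK.
    num_required = 3
    assert num_free - num_required >= watermark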
@@ -238,8 +238,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         self.block_tables[child_seq.seq_id] = src_block_table.fork()
 
     def can_swap_in(self, seq_group: SequenceGroup,
-                    num_lookahead_slots: int) -> bool:
-        return False
+                    num_lookahead_slots: int) -> AllocStatus:
+        return AllocStatus.LATER
 
     def swap_in(self, seq_group: SequenceGroup,
                 num_lookahead_slots: int) -> Dict[int, int]:
@@ -63,7 +63,7 @@ class BlockSpaceManager(ABC):
 
     @abstractmethod
     def can_swap_in(self, seq_group: SequenceGroup,
-                    num_lookahead_slots: int) -> bool:
+                    num_lookahead_slots: int) -> AllocStatus:
         pass
 
     @abstractmethod
@@ -210,6 +210,8 @@ class SchedulerSwappedInOutputs:
     blocks_to_copy: Dict[int, List[int]]
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
+    # Infeasible sequence groups.
+    infeasible_seq_groups: List[SequenceGroup]
 
     @classmethod
     def create_empty(cls) -> "SchedulerSwappedInOutputs":
@@ -219,6 +221,7 @@
             blocks_to_swap_in={},
             blocks_to_copy={},
             num_lookahead_slots=0,
+            infeasible_seq_groups=[],
         )
 
 
@@ -511,14 +514,26 @@
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
         now = time.time()
         swapped_queue = policy.sort_by_priority(now, swapped_queue)
+        infeasible_seq_groups: List[SequenceGroup] = []
 
         leftover_swapped: Deque[SequenceGroup] = deque()
         while swapped_queue:
             seq_group = swapped_queue[0]
 
             # If the sequence group cannot be swapped in, stop.
-            if not self.block_manager.can_swap_in(seq_group):
+            alloc_status = self.block_manager.can_swap_in(seq_group)
+            if alloc_status == AllocStatus.LATER:
                 break
+            elif alloc_status == AllocStatus.NEVER:
+                logger.warning(
+                    "Failing the request %s because there's not enough kv "
+                    "cache blocks to run the entire sequence.",
+                    seq_group.request_id)
+                for seq in seq_group.get_seqs():
+                    seq.status = SequenceStatus.FINISHED_IGNORED
+                infeasible_seq_groups.append(seq_group)
+                swapped_queue.popleft()
+                continue
 
             lora_int_id = 0
             if self.lora_enabled:
@@ -569,7 +584,9 @@
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_copy=blocks_to_copy,
             num_lookahead_slots=self._get_num_lookahead_slots(
-                is_prefill=False))
+                is_prefill=False),
+            infeasible_seq_groups=infeasible_seq_groups,
+        )
 
     def _schedule_prefills(
         self,
@@ -777,7 +794,8 @@
             blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
             blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                        swapped_in.blocks_to_copy),
-            ignored_seq_groups=prefills.ignored_seq_groups,
+            ignored_seq_groups=prefills.ignored_seq_groups +
+            swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
         )
 
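Folding swapped_in.infeasible_seq_groups into ignored_seq_groups is what surfaces the failure to the caller: ignored groups come back as finished request outputs instead of being rescheduled. A sketch of the user-visible effect, mirroring what the new tests assert; llm stands for an already constructed vllm.LLM instance and the prompt is made up.

    # Sketch only: an infeasible request now comes back finished instead of
    # hanging. FINISHED_IGNORED sequences report finish_reason "length", which
    # is exactly what test_swap_infeasible and test_preemption_infeasible check.
    outputs = llm.generate(["some prompt that cannot fit in the KV cache"],
                           sampling_params)
    for request_output in outputs:
        for completion in request_output.outputs:
            assert completion.finish_reason == "length"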
@@ -893,15 +911,6 @@
             num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
         )
 
-    def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
-        # Swapping in is considered decode.
-        is_prefill = False
-
-        return self.block_manager.can_swap_in(
-            seq_group=seq_group,
-            num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
-        )
-
     def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler