[V1] [P/D] Add Support for KV Load Failure Recovery (#19330)

Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
David Ben-David
2025-10-01 00:57:08 +03:00
committed by yewentao256
parent ef318228e7
commit 8328d39d40
24 changed files with 1039 additions and 86 deletions

View File

@ -0,0 +1,30 @@
# KV Load Failure Recovery Test
This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`.
It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and ensures successful and consistent output.
## Files
- `prefill_example.py` performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`).
- `decode_example.py` performs the decode stage. It accepts:
- `--simulate-failure`: simulates KV load failure using a custom connector.
- `--async-load`: enables asynchronous KV loading mode.
- `rogue_shared_storage_connector.py` defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector` that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request.
- `run.sh` orchestrates the test: runs the prefill stage, then three decode stages:
1. Normal decode (baseline).
2. Decode with simulated sync KV load failure.
3. Decode with simulated async KV load failure.
Finally, it compares the output of the baseline with the recovered outputs to verify correctness.
## How It Works
- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector; see the configuration sketch after this list.
- The decode stages that simulate failure are expected to trigger recovery logic in vLLM, resulting in the same output as the baseline decode.
- If recovery fails, the script prints a unified diff of the output mismatch and exits with an error.
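
For reference, the failure-simulation configuration built by `decode_example.py` (shown in full further below) reduces to the following sketch; the key piece is pointing `kv_connector_module_path` at the rogue connector module:

```python
from vllm.config import KVTransferConfig

# Load RogueSharedStorageConnector from rogue_shared_storage_connector.py
# instead of the built-in SharedStorageConnector.
ktc = KVTransferConfig(
    kv_connector="RogueSharedStorageConnector",
    kv_connector_module_path="rogue_shared_storage_connector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "shared_storage_path": "local_storage",
        "async_load": False,  # set True to exercise the async-loading path
    },
)
```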
## Usage
```bash
./run.sh
```
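
A successful run ends with the message `✅ Outputs match: recovery successful.`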

View File

@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
"""Read prompts from prefill_output.txt"""
prompts = []
try:
with open("prefill_output.txt") as f:
for line in f:
prompts.append(line.strip())
print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
return prompts
except FileNotFoundError:
print("Error: prefill_output.txt file not found")
exit(-1)
def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
parser = argparse.ArgumentParser()
parser.add_argument(
"--simulate-failure", action="store_true", help="Simulate KV load failure."
)
parser.add_argument(
"--async-load", action="store_true", help="Simulate async KV load"
)
args = parser.parse_args()
if args.simulate_failure:
ktc = KVTransferConfig(
kv_connector="RogueSharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage",
"async_load": args.async_load,
},
kv_connector_module_path="rogue_shared_storage_connector",
)
out_file = (
"async_decode_recovered_output.txt"
if args.async_load
else "sync_decode_recovered_output.txt"
)
else:
ktc = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage",
},
)
out_file = "decode_output.txt"
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
max_num_batched_tokens=64,
max_num_seqs=16,
kv_transfer_config=ktc,
)
outputs = llm.generate(prompts, sampling_params)
sep_str = "-" * 30
with open(out_file, "w", encoding="utf-8") as f:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
print(out_str)
print(sep_str)
f.write(out_str)
f.write(sep_str)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
def read_prompts():
context = "Hi " * 1000
context2 = "Hey " * 500
return [
context + "Hello, my name is",
context + "The capital of France is",
context2 + "Your name is",
context2 + "The capital of China is",
]
def main():
prompts = read_prompts()
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
),
) # , max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs = llm.generate(
prompts,
sampling_params,
)
new_prompts = []
print("-" * 30)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
new_prompts.append(prompt + generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)
# Write new_prompts to prefill_output.txt
with open("prefill_output.txt", "w") as f:
for prompt in new_prompts:
f.write(prompt + "\n")
print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
SharedStorageConnector,
SharedStorageConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
@dataclass
class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata):
req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)
@classmethod
def from_base(cls, base: SharedStorageConnectorMetadata):
return cls(requests=base.requests)
class RogueSharedStorageConnector(SharedStorageConnector):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
"async_load", False
)
self._invalid_block_ids: set = None
self._seen_requests: set = set()
self._req_to_block_ids: dict[str, list[int]] = dict()
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata)
index, failed_request = next(
(
(i, x)
for i, x in enumerate(connector_metadata.requests)
if not x.is_store
),
(None, None),
)
if index is not None:
del connector_metadata.requests[index]
self._invalid_block_ids = set(
(
failed_request.slot_mapping[:: self._block_size] // self._block_size
).tolist()
)
logger.info(
"Simulating failure to load all KV blocks for the "
"first load request. Total blocks: %d",
len(self._invalid_block_ids),
)
super().bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None:
self._invalid_block_ids = None
super().clear_connector_metadata()
def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
if self._async_load and forward_context.attn_metadata is None:
# Bypass sanity check in super().start_load_kv
forward_context.attn_metadata = "None"
super().start_load_kv(forward_context, **kwargs)
def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if self._async_load:
meta = self._get_connector_metadata()
assert isinstance(meta, RogueSharedStorageConnectorMetadata)
if meta.req_to_block_ids:
return None, set(meta.req_to_block_ids)
return None, None
def get_block_ids_with_load_errors(self) -> set[int]:
return self._invalid_block_ids
def get_num_new_matched_tokens(
self,
request: Request,
num_computed_tokens: int,
) -> tuple[int, bool]:
if request.request_id in self._seen_requests:
return 0, False
self._seen_requests.add(request.request_id)
num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
return num_tokens, self._async_load and num_tokens > 0
def update_state_after_alloc(
self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
):
"""
Update KVConnector state after block allocation.
If blocks were allocated, add to _requests_need_load,
such that we load the KVs in the next forward pass.
"""
super().update_state_after_alloc(request, blocks, num_external_tokens)
if num_external_tokens > 0:
self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]
def build_connector_meta(
self,
scheduler_output: "SchedulerOutput",
) -> KVConnectorMetadata:
if not self._async_load:
base = super().build_connector_meta(scheduler_output)
meta = RogueSharedStorageConnectorMetadata.from_base(base)
else:
meta = RogueSharedStorageConnectorMetadata()
if self._requests_need_load:
for req_id, request in self._requests_need_load.items():
meta.add_request(
token_ids=request.prompt_token_ids,
block_ids=self._req_to_block_ids[req_id],
block_size=self._block_size,
is_store=False,
mm_hashes=[],
)
# Clear state
self._requests_need_load.clear()
meta.req_to_block_ids = self._req_to_block_ids
self._req_to_block_ids = dict()
return meta
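
To see how `bind_connector_metadata` above turns the failed request's `slot_mapping` into block IDs: each slot index is `block_id * block_size + offset_within_block`, so taking every `block_size`-th entry and integer-dividing by `block_size` yields one ID per block. A small worked example with illustrative values (a block size of 4 is assumed):

```python
import torch

block_size = 4
# The tokens below occupy blocks 2 and 5 of the KV cache:
# slot = block_id * block_size + offset_within_block.
slot_mapping = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])
invalid_block_ids = set((slot_mapping[::block_size] // block_size).tolist())
print(invalid_block_ids)  # {2, 5}
```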

View File

@ -0,0 +1,33 @@
#!/bin/bash
# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"
# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load
# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
echo "❌ Outputs differ: sync recovery failed."
diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
exit 1
fi
if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
echo "❌ Outputs differ: async recovery failed."
diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
exit 1
fi
echo "✅ Outputs match: recovery successful."

View File

@ -0,0 +1,341 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
from .utils import (create_model_runner_output, create_request,
create_scheduler, create_vllm_config)
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request,
_: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def scheduler():
vllm_config = create_vllm_config()
return create_scheduler(vllm_config)
@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"invalid_block_idxs",
[
(100, 99, {0, 98}),
(100, 99, {50, 98}),
(100, 99, {98}),
],
)
def test_async_load_failure(
scheduler: Scheduler,
num_prompt_blocks: int,
num_external_computed_blocks: int,
invalid_block_idxs: set[int],
):
assert num_prompt_blocks >= num_external_computed_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
request1 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request1)
request2 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request2)
request3 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request3)
# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: num_external_computed_tokens,
request3.request_id: num_external_computed_tokens,
}
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=True))
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
assert request.num_computed_tokens == 0
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
# Simulate a failure in loading some of request2 blocks.
(req2_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
request2.request_id)
invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving={request1.request_id, request3.request_id},
invalid_block_ids=invalid_block_ids,
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
min_invalid_block_idx = min(invalid_block_idxs)
assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
if request.request_id == request2.request_id:
assert request.num_computed_tokens == (min_invalid_block_idx *
scheduler.block_size)
else:
assert request.num_computed_tokens == 0
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request2.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"invalid_block_idxs",
[
(100, 99, {0, 98}),
(100, 99, {50, 98}),
(100, 99, {98}),
],
)
def test_sync_load_failure(
scheduler: Scheduler,
num_prompt_blocks: int,
num_external_computed_blocks: int,
invalid_block_idxs: set[int],
):
assert num_prompt_blocks >= num_external_computed_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
request1 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request1)
request2 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request2)
request3 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request3)
# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: num_external_computed_tokens,
request3.request_id: num_external_computed_tokens,
}
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=False))
scheduler.connector.request_finished.return_value = (False, None)
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
# req_id -> num_computed_tokens
expected_computed_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: num_external_computed_tokens,
request3.request_id: num_external_computed_tokens,
}
assert len(scheduler.running) == 3
assert len(scheduler_output.scheduled_new_reqs) == 3
for request in scheduler_output.scheduled_new_reqs:
assert request.num_computed_tokens == expected_computed_tokens[
request.req_id]
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
# Simulate a failure in loading some of request2 blocks.
req2_block_ids = scheduler_output.scheduled_new_reqs[1].block_ids[0]
invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
[request1, request2, request3],
invalid_block_ids=invalid_block_ids,
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == request2.request_id
assert scheduler.running[0].num_computed_tokens == (
min(invalid_block_idxs) * scheduler.block_size)
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
assert scheduler.connector.request_finished.call_count == 2
@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"num_common_prefix_blocks,"
"invalid_block_idxs",
[
(100, 99, 50, {0, 49}),
(100, 99, 50, {25, 49}),
(100, 99, 50, {49}),
],
)
def test_sync_load_failure_with_shared_blocks(
scheduler: Scheduler,
num_prompt_blocks: int,
num_external_computed_blocks: int,
num_common_prefix_blocks: int,
invalid_block_idxs: set[int],
):
assert (num_prompt_blocks >= num_external_computed_blocks >=
num_common_prefix_blocks)
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
common_prefix_len = num_common_prefix_blocks * scheduler.block_size
request1 = create_request(num_tokens=num_prompt_tokens,
common_prefix_len=common_prefix_len)
scheduler.add_request(request=request1)
request2 = create_request(num_tokens=num_prompt_tokens,
common_prefix_len=common_prefix_len)
scheduler.add_request(request=request2)
# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
}
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=False))
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
# req_id -> num_computed_tokens
expected_computed_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: common_prefix_len,
}
assert len(scheduler.running) == 2
assert len(scheduler_output.scheduled_new_reqs) == 2
for request in scheduler_output.scheduled_new_reqs:
assert request.num_computed_tokens == expected_computed_tokens[
request.req_id]
assert scheduler.connector.get_num_new_matched_tokens.call_count == 2
# Simulate a failure in loading some of the shared blocks.
req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
[request1, request2],
invalid_block_ids=invalid_block_ids,
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
# req_id -> num_computed_tokens
# all the common prefix blocks will be computed by request1
expected_computed_tokens = {
request1.request_id: min(invalid_block_idxs) * scheduler.block_size,
request2.request_id: common_prefix_len,
}
assert len(scheduler.running) == 2
for request in scheduler.running:
assert request.num_computed_tokens == expected_computed_tokens[
request.request_id]
assert scheduler.connector.get_num_new_matched_tokens.call_count == 2
@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"invalid_block_idxs",
[
(100, 99, {0, 50, 98}),
(100, 99, {98, 50, 0}),
],
)
def test_async_progressive_load_failure(
scheduler: Scheduler,
num_prompt_blocks: int,
num_external_computed_blocks: int,
invalid_block_idxs: set[int],
):
assert num_prompt_blocks >= num_external_computed_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = (num_external_computed_blocks *
scheduler.block_size)
request = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request)
# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens,
async_load=True))
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == 0
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
min_invalid_block_idx = max(invalid_block_idxs) + 1
# Simulate failures when progressively loading request blocks.
for invalid_block_idx in invalid_block_idxs:
(req_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=set(),
invalid_block_ids=invalid_block_ids,
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)
assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request(
).request_id == request.request_id
assert request.num_computed_tokens == (min_invalid_block_idx *
scheduler.block_size)
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1

View File

@ -281,8 +281,8 @@ class RequestRunner:
model_runner_output = create_model_runner_output(
reqs=self.scheduler.running,
finished_sending=list(finished_sending),
finished_recving=list(finished_recving),
finished_sending=finished_sending,
finished_recving=finished_recving,
token_id=token_id)
if self.scheduler.running:

View File

@ -15,26 +15,26 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
def __init__(self,
finished_sending: Optional[set[str]] = None,
finished_recving: Optional[set[str]] = None):
finished_recving: Optional[set[str]] = None,
invalid_block_ids: Optional[set[int]] = None):
self.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
)
invalid_block_ids=invalid_block_ids or set())
def __repr__(self):
return (
f"DummyModelRunnerOutput("
f"finished_sending={self.kv_connector_output.finished_sending},"
f"finished_recving={self.kv_connector_output.finished_recving})")
f"finished_recving={self.kv_connector_output.finished_recving})"
f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})")
def test_aggregate_workers_output():
aggregator = KVOutputAggregator(world_size=2)
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output2 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output1 = DummyModelRunnerOutput()
output2 = DummyModelRunnerOutput()
aggregated = aggregator.aggregate([output1, output2])
@ -42,11 +42,22 @@ def test_aggregate_workers_output():
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
assert not aggregated.invalid_block_ids
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving=None)
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {1}
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
output2 = DummyModelRunnerOutput(finished_sending={'req1'})
aggregated = aggregator.aggregate([output1, output2])
@ -54,11 +65,11 @@ def test_aggregate_workers_output():
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {2}
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
output2 = DummyModelRunnerOutput(finished_recving={'req2'},
invalid_block_ids={4, 5})
aggregated = aggregator.aggregate([output1, output2])
@ -66,6 +77,7 @@ def test_aggregate_workers_output():
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {'req2'}
assert aggregated.invalid_block_ids == {3, 4, 5}
def test_async_aggregate_workers_output():
@ -75,10 +87,8 @@ def test_async_aggregate_workers_output():
future2: Future[DummyModelRunnerOutput] = Future()
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output2 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output1 = DummyModelRunnerOutput()
output2 = DummyModelRunnerOutput()
future1.set_result(output1)
future2.set_result(output2)
@ -88,15 +98,32 @@ def test_async_aggregate_workers_output():
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
assert not aggregated.invalid_block_ids
future1 = Future()
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving=None)
output1 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
future1.set_result(output1)
future2.set_result(output2)
assert result_future.done()
aggregated = result_future.result()
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {1}
future1 = Future()
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
output2 = DummyModelRunnerOutput(finished_sending={'req1'})
future1.set_result(output1)
future2.set_result(output2)
@ -106,15 +133,15 @@ def test_async_aggregate_workers_output():
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {2}
future1 = Future()
future2 = Future()
result_future = aggregator.async_aggregate([future1, future2])
output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
output2 = DummyModelRunnerOutput(finished_sending={'req1'},
finished_recving={'req2'})
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
output2 = DummyModelRunnerOutput(finished_recving={'req2'},
invalid_block_ids={4, 5})
future1.set_result(output1)
future2.set_result(output2)
@ -124,3 +151,4 @@ def test_async_aggregate_workers_output():
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {'req2'}
assert aggregated.invalid_block_ids == {3, 4, 5}

View File

@ -92,7 +92,7 @@ def test_basic_lifecycle():
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request_id])
finished_sending={request_id})
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
@ -139,7 +139,7 @@ def test_short_prompt_lifecycle():
scheduler_output = scheduler.schedule()
# Use create_model_runner_output to pass kv_connector_output along
model_runner_output = create_model_runner_output(
reqs=[request], finished_sending=[request.request_id])
reqs=[request], finished_sending={request.request_id})
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)
@ -195,6 +195,6 @@ def test_prefix_cache_lifecycle():
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request_remote.request_id])
finished_sending={request_remote.request_id})
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)

View File

@ -78,7 +78,7 @@ def test_basic_lifecycle():
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving=[request_id])
finished_recving={request_id})
# (2c): update_from_output():
engine_core_outputs = scheduler.update_from_output(scheduler_output,
@ -197,7 +197,7 @@ def test_interleaved_lifecycle():
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b],
finished_recving=[request_remote.request_id])
finished_recving={request_remote.request_id})
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 5: RECVed KVs are sent to ModelRunner.
@ -246,16 +246,16 @@ def test_no_spurious_prefix_caching():
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
common_prefix_len=NUM_TOKENS,
do_remote_prefill=True,
use_all_1s_for_prompt_tokens=True,
)
request_local = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
common_prefix_len=NUM_TOKENS,
do_remote_prefill=False,
use_all_1s_for_prompt_tokens=True,
)
# Schedule the remote prefill request. This should not
@ -322,7 +322,7 @@ def test_full_block_prompt():
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving=[request_id])
finished_recving={request_id})
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
@ -402,7 +402,7 @@ def test_cannot_schedule_after_recv():
# Step 3: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], finished_recving=[request_remote.request_id])
reqs=[request_normal], finished_recving={request_remote.request_id})
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
@ -516,7 +516,7 @@ def test_cannot_recv():
# Step 5: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[], finished_recving=[request_remote.request_id])
reqs=[], finished_recving={request_remote.request_id})
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections import defaultdict
from itertools import count
from typing import Any, Callable, Optional
import torch
@ -61,12 +62,15 @@ def create_vllm_config(
max_num_seqs: int = 16,
max_num_batched_tokens: int = 64,
block_size: int = 16,
max_model_len: int = 10000,
enable_chunked_prefill: bool = True,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
max_model_len=max_model_len,
enable_chunked_prefill=enable_chunked_prefill,
)
model_config = ModelConfig(
model=model,
@ -117,19 +121,27 @@ def create_scheduler(
)
_request_count = count(1)
_none_hash_initialized = False
def create_request(request_id: int,
num_tokens: int = 10,
max_tokens: int = 16,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
hash_fn: Callable = sha256) -> Request:
def create_request(
request_id: Optional[int] = None,
num_tokens: int = 10,
common_prefix_len=0,
max_tokens: int = 16,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
hash_fn: Callable = sha256,
) -> Request:
"""Make dummy request for testing."""
assert num_tokens >= common_prefix_len >= 0
if request_id is None:
request_id = next(_request_count)
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash_fn)
@ -153,10 +165,9 @@ def create_request(request_id: int,
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
if use_all_1s_for_prompt_tokens:
prompt_token_ids = [1] * num_tokens
else:
prompt_token_ids = [i * request_id for i in range(num_tokens)]
common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else []
suffix = [i * request_id for i in range(num_tokens - common_prefix_len)]
prompt_token_ids = common_prefix + suffix
req = Request(
request_id=f"id-{request_id}",
@ -173,8 +184,9 @@ def create_request(request_id: int,
def create_model_runner_output(
reqs: list[Request],
finished_sending: Optional[list[str]] = None,
finished_recving: Optional[list[str]] = None,
finished_sending: Optional[set[str]] = None,
finished_recving: Optional[set[str]] = None,
invalid_block_ids: Optional[set[int]] = None,
use_eos: bool = False,
token_id: int = 0,
) -> ModelRunnerOutput:
@ -189,10 +201,11 @@ def create_model_runner_output(
sampled_token_ids = [[sampled_token] for _ in req_ids]
kv_connector_output = None if (
finished_sending is None
and finished_recving is None) else KVConnectorOutput(
finished_sending is None and finished_recving is None
and invalid_block_ids is None) else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(),
)
# Make output data structure.

View File

@ -250,6 +250,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
new_token_ids=[[]],
new_block_ids=([[0]], ),
num_computed_tokens=[0],
num_output_tokens=[0],
)
scheduler_output = SchedulerOutput(

View File

@ -117,7 +117,7 @@ def get_kv_connector_cache_layout():
class KVOutputAggregator:
"""Utility class to aggregate the output of all workers into a single
"""Utility class to aggregate the output of all workers into a single
output corresponding to Rank 0 for scheduler."""
def __init__(self, world_size: int):
@ -143,6 +143,7 @@ class KVOutputAggregator:
finished_sending = set[str]()
finished_recving = set[str]()
aggregated_kv_connector_stats = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
output = model_runner_output.kv_connector_output
if not output:
@ -165,6 +166,8 @@ class KVOutputAggregator:
aggregated_kv_connector_stats = \
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
invalid_block_ids |= output.invalid_block_ids
# select output of the worker specified by output_rank
output = outputs[output_rank]
@ -172,6 +175,7 @@ class KVOutputAggregator:
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
invalid_block_ids=invalid_block_ids,
)
return output
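
Note how the loop above unions `invalid_block_ids` across all workers, while the scalar fields come from the rank selected by `output_rank`. A stand-alone sketch of the union semantics (plain sets stand in for per-worker `KVConnectorOutput`s; values mirror the aggregator tests earlier in this diff):

```python
# Per-worker invalid-block reports (illustrative values).
worker_invalid_ids = [{3, 4}, {4, 5}]

aggregated: set[int] = set()
for ids in worker_invalid_ids:
    aggregated |= ids

assert aggregated == {3, 4, 5}
```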

View File

@ -229,6 +229,26 @@ class KVConnectorBase_V1(ABC):
"""
return None, None
def get_block_ids_with_load_errors(self) -> set[int]:
"""
Get the set of block IDs that failed to load.
Returns:
Set of block IDs that encountered load errors.
Empty set if no load errors occurred.
Notes:
- Applies to both sync- and async-loading requests.
- Async loading: failed blocks may be reported in any forward pass
up to and including the pass where the request ID is returned by
`get_finished()`. Even if failures occur, the request must still
be reported via `get_finished()`, and the failed block IDs must
appear here no later than that same pass.
- Sync loading: failed blocks should be reported in the forward
pass in which they are detected.
"""
return set()
def shutdown(self):
"""
Shutdown the connector. This is called when the worker process
@ -264,14 +284,21 @@ class KVConnectorBase_V1(ABC):
Returns:
A tuple with the following elements:
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
Notes:
            The connector should only consider the largest prefix of prompt
tokens for which KV cache is actually available at the time of the
call. If the cache cannot be loaded for some tokens (e.g., due to
connectivity issues or eviction), those tokens must not be taken
into account.
"""
pass
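
A minimal sketch, not part of this commit, of how a worker-side connector might satisfy the `get_block_ids_with_load_errors` contract documented above; the helper class and its method names are hypothetical:

```python
class LoadErrorTracker:
    """Hypothetical helper that accumulates failed block loads."""

    def __init__(self) -> None:
        self._failed_block_ids: set[int] = set()

    def record_failed_load(self, block_id: int) -> None:
        # Called by the loading path whenever a block fetch fails.
        self._failed_block_ids.add(block_id)

    def get_block_ids_with_load_errors(self) -> set[int]:
        # Report each failure no later than the pass in which
        # get_finished() reports the owning request (see contract above).
        failed, self._failed_block_ids = self._failed_block_ids, set()
        return failed
```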

View File

@ -189,6 +189,12 @@ class MultiConnector(KVConnectorBase_V1):
return finished_sending or None, finished_recving or None
def get_block_ids_with_load_errors(self) -> set[int]:
agg_block_ids: set[int] = set()
for c in self._connectors:
agg_block_ids |= c.get_block_ids_with_load_errors()
return agg_block_ids
# ==============================
# Scheduler-side methods
# ==============================

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
import safetensors
@ -55,10 +55,7 @@ class ReqMeta:
@dataclass
class SharedStorageConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta]
def __init__(self):
self.requests = []
requests: list[ReqMeta] = field(default_factory=list)
def add_request(
self,

View File

@ -211,7 +211,7 @@ class BlockPool:
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
"""
if num_cached_blocks == num_full_blocks:
if num_cached_blocks >= num_full_blocks:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(request.block_hashes) >= num_full_blocks

View File

@ -101,6 +101,7 @@ class CachedRequestData:
new_token_ids: list[list[int]]
new_block_ids: list[Optional[tuple[list[int], ...]]]
num_computed_tokens: list[int]
num_output_tokens: list[int]
@property
def num_reqs(self) -> int:
@ -114,6 +115,7 @@ class CachedRequestData:
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=[],
num_output_tokens=[],
)

View File

@ -133,6 +133,7 @@ class Scheduler(SchedulerInterface):
# KV Connector: requests in process of async KV loading or recving
self.finished_recving_kv_req_ids: set[str] = set()
self.failed_recving_kv_req_ids: set[str] = set()
# Encoder-related.
# Calculate encoder cache size if applicable
@ -671,6 +672,7 @@ class Scheduler(SchedulerInterface):
new_token_ids: list[list[int]] = []
new_block_ids: list[Optional[tuple[list[int], ...]]] = []
num_computed_tokens: list[int] = []
num_output_tokens: list[int] = []
use_connector = self.connector is not None
for req in itertools.chain(running_reqs, resumed_reqs):
@ -695,6 +697,7 @@ class Scheduler(SchedulerInterface):
new_block_ids.append(
req_to_new_blocks[req_id].get_block_ids(allow_none=True))
num_computed_tokens.append(req.num_computed_tokens)
num_output_tokens.append(len(req.output_token_ids))
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption = [False] * len(running_reqs)
@ -706,6 +709,7 @@ class Scheduler(SchedulerInterface):
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_output_tokens=num_output_tokens,
)
def _try_schedule_encoder_inputs(
@ -878,6 +882,14 @@ class Scheduler(SchedulerInterface):
kv_connector_stats = (kv_connector_output.kv_connector_stats
if kv_connector_output else None)
failed_kv_load_req_ids = None
if kv_connector_output and kv_connector_output.invalid_block_ids:
# These blocks contain externally computed tokens that failed to
# load. Identify affected requests and adjust their computed token
# count to trigger recomputation of the invalid blocks.
failed_kv_load_req_ids = self._handle_invalid_blocks(
kv_connector_output.invalid_block_ids)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
# to avoid expensive operations inside the loop.
@ -885,6 +897,9 @@ class Scheduler(SchedulerInterface):
stopped_preempted_reqs: set[Request] = set()
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
assert num_tokens_scheduled > 0
if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
# Skip requests that were recovered from KV load failure
continue
request = self.requests.get(req_id)
if request is None:
# The request is already finished. This can happen if the
@ -988,9 +1003,8 @@ class Scheduler(SchedulerInterface):
self.waiting.remove_requests(stopped_preempted_reqs)
# KV Connector: update state for finished KV Transfers.
if model_runner_output.kv_connector_output:
self._update_from_kv_xfer_finished(
model_runner_output.kv_connector_output)
if kv_connector_output:
self._update_from_kv_xfer_finished(kv_connector_output)
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
@ -1252,18 +1266,33 @@ class Scheduler(SchedulerInterface):
if request.request_id not in self.finished_recving_kv_req_ids:
return False
# Now that the blocks are ready, actually cache them.
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
        # Handle the case where the number of request tokens is less than one block.
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
# This will cache the blocks iff caching is enabled.
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
if request.request_id in self.failed_recving_kv_req_ids:
# Request had KV load failures; num_computed_tokens was already
# updated in _update_requests_with_invalid_blocks
if request.num_computed_tokens:
# Cache any valid computed tokens.
self.kv_cache_manager.cache_blocks(request,
request.num_computed_tokens)
else:
# No valid computed tokens, release allocated blocks.
# There may be a local cache hit on retry.
self.kv_cache_manager.free(request)
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
self.failed_recving_kv_req_ids.remove(request.request_id)
else:
# Now that the blocks are ready, actually cache them.
(block_ids, ) = self.kv_cache_manager.get_block_ids(
request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
            # Handle the case where the number of request tokens is less than one block.
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
# This will cache the blocks iff caching is enabled.
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
@ -1296,3 +1325,134 @@ class Scheduler(SchedulerInterface):
"but the request is already freed.", req_id)
else:
self._free_blocks(self.requests[req_id])
def _update_requests_with_invalid_blocks(
self, requests: Iterable[Request],
invalid_block_ids: set[int]) -> tuple[set[str], int]:
"""
Identify and update requests affected by invalid KV cache blocks.
This method scans the given requests, detects those with invalid blocks
and adjusts their `num_computed_tokens` to the longest valid prefix.
For observability, it also accumulates the total number of tokens that
will need to be recomputed across all affected requests.
Args:
requests: The set of requests to scan for invalid blocks.
invalid_block_ids: IDs of invalid blocks.
Returns:
tuple:
- affected_req_ids (set[str]): IDs of requests impacted by
invalid blocks.
- total_affected_tokens (int): Total number of tokens that must
be recomputed across all affected requests (for observability).
"""
affected_req_ids: set[str] = set()
total_affected_tokens = 0
# If a block is invalid and shared by multiple requests in the batch,
# these requests must be rescheduled, but only the first will recompute
# it. This set tracks blocks already marked for recomputation.
marked_invalid_block_ids: set[int] = set()
for request in requests:
is_affected = False
marked_invalid_block = False
req_id = request.request_id
# TODO (davidb): add support for hybrid memory allocator
(req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id)
# We iterate only over blocks that may contain externally computed
# tokens
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
# Async loading. If num_computed_tokens is set it implies we
# already processed some block failures for it in a prior step
req_num_computed_tokens = (
request.num_computed_tokens if req_id
in self.failed_recving_kv_req_ids else len(req_block_ids) *
self.block_size)
else:
# Sync loading. num_computed_tokens includes new tokens
req_num_computed_tokens = request.num_cached_tokens
req_num_computed_blocks = (req_num_computed_tokens +
self.block_size - 1) // self.block_size
for idx, block_id in zip(range(req_num_computed_blocks),
req_block_ids):
if block_id not in invalid_block_ids:
continue
is_affected = True
if block_id in marked_invalid_block_ids:
# This invalid block is shared with a previous request
# and was already marked for recomputation.
# This means this request can still consider this block
# as computed when rescheduled.
# Currently this only applies to sync loading; Async
# loading does not yet support block sharing
continue
marked_invalid_block_ids.add(block_id)
if marked_invalid_block:
# This request has already marked an invalid block for
# recomputation and updated its num_computed_tokens.
continue
marked_invalid_block = True
# Truncate the computed tokens at the first failed block
request.num_computed_tokens = idx * self.block_size
total_affected_tokens += (req_num_computed_tokens -
request.num_computed_tokens)
if is_affected:
if not marked_invalid_block:
# All invalid blocks of this request are shared with
# previous requests and will be recomputed by them.
# Revert to considering only cached tokens as computed.
# Currently this only applies to sync loading; Async
# loading does not yet support block sharing
total_affected_tokens += (request.num_computed_tokens -
request.num_cached_tokens)
request.num_computed_tokens = request.num_cached_tokens
affected_req_ids.add(request.request_id)
return (affected_req_ids, total_affected_tokens)
def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
total_requests_to_reschedule = 0
total_tokens_to_reschedule = 0
# --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) ---
async_load_reqs = (
req for req in self.waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
async_affected_req_ids, num_tokens_to_reschedule = (
self._update_requests_with_invalid_blocks(async_load_reqs,
invalid_block_ids))
total_requests_to_reschedule += len(async_affected_req_ids)
total_tokens_to_reschedule += num_tokens_to_reschedule
# Mark requests with async KV load failures; they will be rescheduled
# once loading completes
self.failed_recving_kv_req_ids |= async_affected_req_ids
# --- Handle sync KV loads (running requests) ---
sync_affected_req_ids, num_tokens_to_reschedule = (
self._update_requests_with_invalid_blocks(self.running,
invalid_block_ids))
total_requests_to_reschedule += len(sync_affected_req_ids)
total_tokens_to_reschedule += num_tokens_to_reschedule
if total_requests_to_reschedule:
logger.warning(
"Recovered from KV load failure: "
"%d request(s) rescheduled (%d tokens affected).",
total_requests_to_reschedule, total_tokens_to_reschedule)
# Return the IDs of affected running requests to skip in
# update_from_output.
return sync_affected_req_ids
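
A worked example of the truncation rule above, using the same values as the scheduler unit tests earlier in this diff (99 externally computed blocks, block size 16, invalid blocks at indices 50 and 98):

```python
block_size = 16
num_external_computed_blocks = 99
invalid_block_idxs = {50, 98}

# Computed tokens are truncated at the first (lowest-index) invalid block;
# everything from that block onward must be recomputed.
num_computed_tokens = min(invalid_block_idxs) * block_size
tokens_to_reschedule = num_external_computed_blocks * block_size - num_computed_tokens
print(num_computed_tokens, tokens_to_reschedule)  # 800 784
```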

View File

@ -142,6 +142,9 @@ class SingleTypeKVCacheManager(ABC):
num_cached_blocks = self.num_cached_block[request.request_id]
num_full_blocks = num_tokens // self.block_size
if num_cached_blocks >= num_full_blocks:
return
self.block_pool.cache_full_blocks(
request=request,
blocks=self.req_to_blocks[request.request_id],

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
import torch
@ -87,10 +87,13 @@ class KVConnectorOutput:
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
kv_connector_stats: Optional["KVConnectorStats"] = None
# IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them.
invalid_block_ids: set[int] = field(default_factory=set)
def is_empty(self):
return (not self.finished_sending and not self.finished_recving
and not self.kv_connector_stats)
and not self.kv_connector_stats and not self.invalid_block_ids)
# ModelRunnerOutput is serialized and sent to the scheduler process.

View File

@ -634,8 +634,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
num_output_tokens = req_data.num_output_tokens[i]
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank:
@ -653,6 +655,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state.
del req_state.output_token_ids[num_output_tokens:]
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is not None:
old_end_idx = self.input_batch.num_tokens_no_spec[
req_index]
end_idx = self.input_batch.num_prompt_tokens[
req_index] + num_output_tokens
self.input_batch.num_tokens[req_index] = end_idx
self.input_batch.num_tokens_no_spec[req_index] = end_idx
self.input_batch.is_token_ids[req_index,
end_idx:old_end_idx] = False
# Update the block IDs.
if not resumed_from_preemption:
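
The cached-state alignment above can be shown in isolation: when the scheduler reports fewer output tokens than the runner has cached (some were discarded after a sync KV-load failure), the runner truncates its copy. A sketch with plain lists standing in for the runner's state:

```python
output_token_ids = [11, 12, 13, 14]  # tokens cached by the model runner
num_output_tokens = 2                # count reported by the scheduler

if num_output_tokens < len(output_token_ids):
    # Drop tokens invalidated by the sync KV-load failure.
    del output_token_ids[num_output_tokens:]

assert output_token_ids == [11, 12]
```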

View File

@ -464,8 +464,7 @@ class Worker(WorkerBase):
# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving):
if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)

View File

@ -75,8 +75,7 @@ class KVConnectorModelRunnerMixin:
scheduler_output, wait_for_save=False) as kv_connector_output:
pass
if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving):
if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
@ -120,6 +119,8 @@ class KVConnectorModelRunnerMixin:
output.finished_sending, output.finished_recving = (
kv_connector.get_finished(scheduler_output.finished_req_ids))
output.invalid_block_ids = (
kv_connector.get_block_ids_with_load_errors())
output.kv_connector_stats = KVConnectorModelRunnerMixin.\
get_kv_connector_stats()