[V1] [P/D] Add Support for KV Load Failure Recovery (#19330)
Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
@@ -0,0 +1,30 @@
# KV Load Failure Recovery Test

This example builds upon the `disaggregated-prefill-v1` example in `examples/offline_inference`.

It demonstrates vLLM's ability to recover from KV load failures in both synchronous and asynchronous loading modes. The goal is to verify that vLLM correctly identifies invalid KV blocks, reschedules the affected requests, and ensures successful and consistent output.

## Files

- `prefill_example.py` – performs the prefill stage and saves KV data (same as in `disaggregated-prefill-v1`).
- `decode_example.py` – performs the decode stage. Accepts:
    - `--simulate-failure`: simulates KV load failure using a custom connector.
    - `--async-load`: enables asynchronous KV loading mode.
- `rogue_shared_storage_connector.py` – defines `RogueSharedStorageConnector`, a subclass of `SharedStorageConnector` that simulates missing or corrupted external KV blocks by failing to load blocks for the first decode request.
- `run.sh` – orchestrates the test: runs the prefill stage, then three decode stages:
    1. Normal decode (baseline).
    2. Decode with simulated sync KV load failure.
    3. Decode with simulated async KV load failure.

  Finally, it compares the baseline output with the recovered outputs to verify correctness.

## How It Works

- The test dynamically loads `RogueSharedStorageConnector` via `KVTransferConfig.kv_connector_module_path`, enabling controlled simulation of load failures without modifying the original connector (see the sketch below).
- The decode stages that simulate failure are expected to trigger vLLM's recovery logic, producing the same output as the baseline decode.
- If recovery fails, the script prints a unified diff of the output mismatch and exits with an error.
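
A condensed sketch of that configuration, as built in `decode_example.py`:

```python
from vllm.config import KVTransferConfig

# Point vLLM at the rogue connector's module so the stock
# SharedStorageConnector is left untouched.
ktc = KVTransferConfig(
    kv_connector="RogueSharedStorageConnector",
    kv_connector_module_path="rogue_shared_storage_connector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "shared_storage_path": "local_storage",
        "async_load": True,  # False exercises the sync failure path
    },
)
```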

## Usage

```bash
./run.sh
```
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    """Read prompts from prefill_output.txt"""
    prompts = []
    try:
        with open("prefill_output.txt") as f:
            for line in f:
                prompts.append(line.strip())
        print(f"Loaded {len(prompts)} prompts from prefill_output.txt")
        return prompts
    except FileNotFoundError:
        print("Error: prefill_output.txt file not found")
        exit(-1)


def main():
    prompts = read_prompts()
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--simulate-failure", action="store_true", help="Simulate KV load failure."
    )
    parser.add_argument(
        "--async-load", action="store_true", help="Simulate async KV load"
    )
    args = parser.parse_args()

    if args.simulate_failure:
        ktc = KVTransferConfig(
            kv_connector="RogueSharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
                "async_load": args.async_load,
            },
            kv_connector_module_path="rogue_shared_storage_connector",
        )
        out_file = (
            "async_decode_recovered_output.txt"
            if args.async_load
            else "sync_decode_recovered_output.txt"
        )
    else:
        ktc = KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage",
            },
        )
        out_file = "decode_output.txt"

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        max_num_batched_tokens=64,
        max_num_seqs=16,
        kv_transfer_config=ktc,
    )

    outputs = llm.generate(prompts, sampling_params)

    sep_str = "-" * 30
    with open(out_file, "w", encoding="utf-8") as f:
        for output in outputs:
            prompt = output.prompt
            generated_text = output.outputs[0].text
            out_str = f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}"
            print(out_str)
            print(sep_str)
            f.write(out_str)
            f.write(sep_str)


if __name__ == "__main__":
    main()
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig


def read_prompts():
    context = "Hi " * 1000
    context2 = "Hey " * 500
    return [
        context + "Hello, my name is",
        context + "The capital of France is",
        context2 + "Your name is",
        context2 + "The capital of China is",
    ]


def main():
    prompts = read_prompts()

    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        enforce_eager=True,
        gpu_memory_utilization=0.8,
        kv_transfer_config=KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        ),
    )  # , max_model_len=2048, max_num_batched_tokens=2048)

    # 1ST generation (prefill instance)
    outputs = llm.generate(
        prompts,
        sampling_params,
    )

    new_prompts = []
    print("-" * 30)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        new_prompts.append(prompt + generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 30)

    # Write new_prompts to prefill_output.txt
    with open("prefill_output.txt", "w") as f:
        for prompt in new_prompts:
            f.write(prompt + "\n")
    print(f"Saved {len(new_prompts)} prompts to prefill_output.txt")


if __name__ == "__main__":
    main()
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
    SharedStorageConnector,
    SharedStorageConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


@dataclass
class RogueSharedStorageConnectorMetadata(SharedStorageConnectorMetadata):
    req_to_block_ids: dict[str, set[int]] = field(default_factory=dict)

    @classmethod
    def from_base(cls, base: SharedStorageConnectorMetadata):
        return cls(requests=base.requests)


class RogueSharedStorageConnector(SharedStorageConnector):
    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._async_load = vllm_config.kv_transfer_config.get_from_extra_config(
            "async_load", False
        )
        self._invalid_block_ids: set = None
        self._seen_requests: set = set()
        self._req_to_block_ids: dict[str, list[int]] = dict()

    def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
        assert isinstance(connector_metadata, RogueSharedStorageConnectorMetadata)
        index, failed_request = next(
            (
                (i, x)
                for i, x in enumerate(connector_metadata.requests)
                if not x.is_store
            ),
            (None, None),
        )
        if index is not None:
            del connector_metadata.requests[index]
            self._invalid_block_ids = set(
                (
                    failed_request.slot_mapping[:: self._block_size] // self._block_size
                ).tolist()
            )
            logger.info(
                "Simulating failure to load all KV blocks for the "
                "first load request. Total blocks: %d",
                len(self._invalid_block_ids),
            )
        super().bind_connector_metadata(connector_metadata)

    def clear_connector_metadata(self) -> None:
        self._invalid_block_ids = None
        super().clear_connector_metadata()

    def start_load_kv(self, forward_context: ForwardContext, **kwargs) -> None:
        if self._async_load and forward_context.attn_metadata is None:
            # Bypass sanity check in super().start_load_kv
            forward_context.attn_metadata = "None"

        super().start_load_kv(forward_context, **kwargs)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        if self._async_load:
            meta = self._get_connector_metadata()
            assert isinstance(meta, RogueSharedStorageConnectorMetadata)
            if meta.req_to_block_ids:
                return None, set(meta.req_to_block_ids)

        return None, None

    def get_block_ids_with_load_errors(self) -> set[int]:
        return self._invalid_block_ids

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        if request.request_id in self._seen_requests:
            return 0, False

        self._seen_requests.add(request.request_id)

        num_tokens, _ = super().get_num_new_matched_tokens(request, num_computed_tokens)
        return num_tokens, self._async_load and num_tokens > 0

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        If blocks were allocated, add to _requests_need_load,
        such that we load the KVs in the next forward pass.
        """
        super().update_state_after_alloc(request, blocks, num_external_tokens)

        if num_external_tokens > 0:
            self._req_to_block_ids[request.request_id] = blocks.get_block_ids()[0]

    def build_connector_meta(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> KVConnectorMetadata:
        if not self._async_load:
            base = super().build_connector_meta(scheduler_output)
            meta = RogueSharedStorageConnectorMetadata.from_base(base)
        else:
            meta = RogueSharedStorageConnectorMetadata()
            if self._requests_need_load:
                for req_id, request in self._requests_need_load.items():
                    meta.add_request(
                        token_ids=request.prompt_token_ids,
                        block_ids=self._req_to_block_ids[req_id],
                        block_size=self._block_size,
                        is_store=False,
                        mm_hashes=[],
                    )
                # Clear state
                self._requests_need_load.clear()
        meta.req_to_block_ids = self._req_to_block_ids
        self._req_to_block_ids = dict()
        return meta
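
A note on the block-ID arithmetic in `bind_connector_metadata` above: `slot_mapping` holds one KV-cache slot index per token, so sampling every `block_size`-th entry and integer-dividing by `block_size` recovers the ID of each block the request touches. A small plain-Python illustration with made-up values (the real code operates on a torch tensor):

```python
block_size = 4
# Slot indices for 8 tokens stored in blocks 7 and 2 (hypothetical values).
slot_mapping = [28, 29, 30, 31, 8, 9, 10, 11]

# One representative slot per block, then slot // block_size -> block ID.
invalid_block_ids = {s // block_size for s in slot_mapping[::block_size]}
assert invalid_block_ids == {7, 2}  # every touched block gets reported as failed
```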
examples/offline_inference/kv_load_failure_recovery/run.sh (new executable file, 33 lines)
@@ -0,0 +1,33 @@
#!/bin/bash

# Constants
SHARED_STORAGE_DIR="local_storage"
PREFILL_OUTPUT="prefill_output.txt"
DECODE_OUTPUT="decode_output.txt"
SYNC_DECODE_RECOVERED_OUTPUT="sync_decode_recovered_output.txt"
ASYNC_DECODE_RECOVERED_OUTPUT="async_decode_recovered_output.txt"

# Cleanup
rm -rf "$SHARED_STORAGE_DIR"
rm -f "$PREFILL_OUTPUT" "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"

# Run inference examples
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py --simulate-failure --async-load

# Compare outputs
if ! cmp -s "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: sync recovery failed."
    diff -u "$DECODE_OUTPUT" "$SYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

if ! cmp -s "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"; then
    echo "❌ Outputs differ: async recovery failed."
    diff -u "$DECODE_OUTPUT" "$ASYNC_DECODE_RECOVERED_OUTPUT"
    exit 1
fi

echo "✅ Outputs match: recovery successful."
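
When debugging a single mode, each stage can also be run on its own with the same environment settings the script uses, for example:

```bash
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 \
    python3 decode_example.py --simulate-failure --async-load
```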
tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py (new file, 341 lines)
@@ -0,0 +1,341 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Callable
from unittest.mock import Mock

import pytest

from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus

from .utils import (create_model_runner_output, create_request,
                    create_scheduler, create_vllm_config)


def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load,
) -> Callable[[Request, int], tuple[int, bool]]:

    def get_num_new_matched_tokens(request: Request,
                                   _: int) -> tuple[int, bool]:
        value = req_num_new_matched_tokens.get(request.request_id, 0)
        return value, async_load

    return get_num_new_matched_tokens


@pytest.fixture
def scheduler():
    vllm_config = create_vllm_config()
    return create_scheduler(vllm_config)


@pytest.mark.parametrize(
    "num_prompt_blocks,"
    "num_external_computed_blocks,"
    "invalid_block_idxs",
    [
        (100, 99, {0, 98}),
        (100, 99, {50, 98}),
        (100, 99, {98}),
    ],
)
def test_async_load_failure(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    invalid_block_idxs: set[int],
):
    assert num_prompt_blocks >= num_external_computed_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = (num_external_computed_blocks *
                                    scheduler.block_size)

    request1 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request1)
    request2 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request2)
    request3 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request3)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: num_external_computed_tokens,
        request3.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens,
                                         async_load=True))
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    assert len(scheduler.waiting) == 3
    for request in scheduler.waiting:
        assert request.num_computed_tokens == 0
        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

    # Simulate a failure in loading some of request2 blocks.
    (req2_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
        request2.request_id)
    invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
    model_runner_output = create_model_runner_output(
        reqs=[],
        finished_recving={request1.request_id, request3.request_id},
        invalid_block_ids=invalid_block_ids,
        use_eos=True)

    scheduler.update_from_output(scheduler_output, model_runner_output)

    min_invalid_block_idx = min(invalid_block_idxs)

    assert len(scheduler.waiting) == 3
    for request in scheduler.waiting:
        if request.request_id == request2.request_id:
            assert request.num_computed_tokens == (min_invalid_block_idx *
                                                   scheduler.block_size)
        else:
            assert request.num_computed_tokens == 0
        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert scheduler.failed_recving_kv_req_ids == {request2.request_id}
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3


@pytest.mark.parametrize(
    "num_prompt_blocks,"
    "num_external_computed_blocks,"
    "invalid_block_idxs",
    [
        (100, 99, {0, 98}),
        (100, 99, {50, 98}),
        (100, 99, {98}),
    ],
)
def test_sync_load_failure(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    invalid_block_idxs: set[int],
):
    assert num_prompt_blocks >= num_external_computed_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = (num_external_computed_blocks *
                                    scheduler.block_size)

    request1 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request1)
    request2 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request2)
    request3 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request3)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: num_external_computed_tokens,
        request3.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens,
                                         async_load=False))
    scheduler.connector.request_finished.return_value = (False, None)
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    # req_id -> num_computed_tokens
    expected_computed_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: num_external_computed_tokens,
        request3.request_id: num_external_computed_tokens,
    }

    assert len(scheduler.running) == 3
    assert len(scheduler_output.scheduled_new_reqs) == 3
    for request in scheduler_output.scheduled_new_reqs:
        assert request.num_computed_tokens == expected_computed_tokens[
            request.req_id]
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

    # Simulate a failure in loading some of request2 blocks.
    req2_block_ids = scheduler_output.scheduled_new_reqs[1].block_ids[0]
    invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
    model_runner_output = create_model_runner_output(
        [request1, request2, request3],
        invalid_block_ids=invalid_block_ids,
        use_eos=True)

    scheduler.update_from_output(scheduler_output, model_runner_output)

    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == request2.request_id
    assert scheduler.running[0].num_computed_tokens == (
        min(invalid_block_idxs) * scheduler.block_size)
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
    assert scheduler.connector.request_finished.call_count == 2


@pytest.mark.parametrize(
    "num_prompt_blocks,"
    "num_external_computed_blocks,"
    "num_common_prefix_blocks,"
    "invalid_block_idxs",
    [
        (100, 99, 50, {0, 49}),
        (100, 99, 50, {25, 49}),
        (100, 99, 50, {49}),
    ],
)
def test_sync_load_failure_with_shared_blocks(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    num_common_prefix_blocks: int,
    invalid_block_idxs: set[int],
):
    assert (num_prompt_blocks >= num_external_computed_blocks >=
            num_common_prefix_blocks)

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = (num_external_computed_blocks *
                                    scheduler.block_size)
    common_prefix_len = num_common_prefix_blocks * scheduler.block_size

    request1 = create_request(num_tokens=num_prompt_tokens,
                              common_prefix_len=common_prefix_len)
    scheduler.add_request(request=request1)
    request2 = create_request(num_tokens=num_prompt_tokens,
                              common_prefix_len=common_prefix_len)
    scheduler.add_request(request=request2)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens,
                                         async_load=False))
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    # req_id -> num_computed_tokens
    expected_computed_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: common_prefix_len,
    }

    assert len(scheduler.running) == 2
    assert len(scheduler_output.scheduled_new_reqs) == 2
    for request in scheduler_output.scheduled_new_reqs:
        assert request.num_computed_tokens == expected_computed_tokens[
            request.req_id]
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 2

    # Simulate a failure in loading some of the shared blocks.
    req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs}
    model_runner_output = create_model_runner_output(
        [request1, request2],
        invalid_block_ids=invalid_block_ids,
        use_eos=True)

    scheduler.update_from_output(scheduler_output, model_runner_output)

    # req_id -> num_computed_tokens
    # all the common prefix blocks will be computed by request1
    expected_computed_tokens = {
        request1.request_id: min(invalid_block_idxs) * scheduler.block_size,
        request2.request_id: common_prefix_len,
    }

    assert len(scheduler.running) == 2
    for request in scheduler.running:
        assert request.num_computed_tokens == expected_computed_tokens[
            request.request_id]
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 2


@pytest.mark.parametrize(
    "num_prompt_blocks,"
    "num_external_computed_blocks,"
    "invalid_block_idxs",
    [
        (100, 99, {0, 50, 98}),
        (100, 99, {98, 50, 0}),
    ],
)
def test_async_progressive_load_failure(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    invalid_block_idxs: set[int],
):
    assert num_prompt_blocks >= num_external_computed_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = (num_external_computed_blocks *
                                    scheduler.block_size)

    request = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens,
                                         async_load=True))
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    assert len(scheduler.waiting) == 1
    assert scheduler.waiting.peek_request().request_id == request.request_id
    assert request.num_computed_tokens == 0
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 1

    min_invalid_block_idx = max(invalid_block_idxs) + 1
    # Simulate failures when progressively loading request blocks.
    for invalid_block_idx in invalid_block_idxs:
        (req_block_ids, ) = scheduler.kv_cache_manager.get_block_ids(
            request.request_id)
        invalid_block_ids = {req_block_ids[invalid_block_idx]}
        model_runner_output = create_model_runner_output(
            reqs=[],
            finished_recving=set(),
            invalid_block_ids=invalid_block_ids,
            use_eos=True)

        scheduler.update_from_output(scheduler_output, model_runner_output)

        min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)

        assert len(scheduler.waiting) == 1
        assert scheduler.waiting.peek_request(
        ).request_id == request.request_id
        assert request.num_computed_tokens == (min_invalid_block_idx *
                                               scheduler.block_size)
        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
        assert scheduler.failed_recving_kv_req_ids == {request.request_id}
        assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
@@ -281,8 +281,8 @@ class RequestRunner:

         model_runner_output = create_model_runner_output(
             reqs=self.scheduler.running,
-            finished_sending=list(finished_sending),
-            finished_recving=list(finished_recving),
+            finished_sending=finished_sending,
+            finished_recving=finished_recving,
             token_id=token_id)

         if self.scheduler.running:
@@ -15,26 +15,26 @@ class DummyModelRunnerOutput(ModelRunnerOutput):

     def __init__(self,
                  finished_sending: Optional[set[str]] = None,
-                 finished_recving: Optional[set[str]] = None):
+                 finished_recving: Optional[set[str]] = None,
+                 invalid_block_ids: Optional[set[int]] = None):
         self.kv_connector_output = KVConnectorOutput(
             finished_sending=finished_sending,
             finished_recving=finished_recving,
-        )
+            invalid_block_ids=invalid_block_ids or set())

     def __repr__(self):
         return (
             f"DummyModelRunnerOutput("
             f"finished_sending={self.kv_connector_output.finished_sending},"
-            f"finished_recving={self.kv_connector_output.finished_recving})")
+            f"finished_recving={self.kv_connector_output.finished_recving})"
+            f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})")


 def test_aggregate_workers_output():
     aggregator = KVOutputAggregator(world_size=2)

-    output1 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving={'req2'})
-    output2 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
+    output1 = DummyModelRunnerOutput()
+    output2 = DummyModelRunnerOutput()

     aggregated = aggregator.aggregate([output1, output2])

@@ -42,11 +42,22 @@ def test_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
+    assert not aggregated.invalid_block_ids

-    output1 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
-    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving=None)
+    output1 = DummyModelRunnerOutput(finished_sending={'req1'},
+                                     finished_recving={'req2'})
+    output2 = DummyModelRunnerOutput(invalid_block_ids={1})

     aggregated = aggregator.aggregate([output1, output2])

+    assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
+    assert aggregated.finished_sending is None
+    assert aggregated.finished_recving is None
+    assert aggregated.invalid_block_ids == {1}
+
+    output1 = DummyModelRunnerOutput(invalid_block_ids={2})
+    output2 = DummyModelRunnerOutput(finished_sending={'req1'})
+
+    aggregated = aggregator.aggregate([output1, output2])
+
@@ -54,11 +65,11 @@ def test_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
+    assert aggregated.invalid_block_ids == {2}

-    output1 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
-    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving={'req2'})
+    output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
+    output2 = DummyModelRunnerOutput(finished_recving={'req2'},
+                                     invalid_block_ids={4, 5})

     aggregated = aggregator.aggregate([output1, output2])

@@ -66,6 +77,7 @@ def test_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
+    assert aggregated.invalid_block_ids == {3, 4, 5}


 def test_async_aggregate_workers_output():
@@ -75,10 +87,8 @@ def test_async_aggregate_workers_output():
     future2: Future[DummyModelRunnerOutput] = Future()
     result_future = aggregator.async_aggregate([future1, future2])

-    output1 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving={'req2'})
-    output2 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
+    output1 = DummyModelRunnerOutput()
+    output2 = DummyModelRunnerOutput()
     future1.set_result(output1)
     future2.set_result(output2)

@@ -88,15 +98,32 @@ def test_async_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving is None
+    assert not aggregated.invalid_block_ids

     future1 = Future()
     future2 = Future()
     result_future = aggregator.async_aggregate([future1, future2])

-    output1 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
-    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving=None)
+    output1 = DummyModelRunnerOutput(finished_sending={'req1'},
+                                     finished_recving={'req2'})
+    output2 = DummyModelRunnerOutput(invalid_block_ids={1})
     future1.set_result(output1)
     future2.set_result(output2)

+    assert result_future.done()
+    aggregated = result_future.result()
+    assert aggregated is output1
+    aggregated = aggregated.kv_connector_output
+    assert aggregated.finished_sending is None
+    assert aggregated.finished_recving is None
+    assert aggregated.invalid_block_ids == {1}
+
+    future1 = Future()
+    future2 = Future()
+    result_future = aggregator.async_aggregate([future1, future2])
+
+    output1 = DummyModelRunnerOutput(invalid_block_ids={2})
+    output2 = DummyModelRunnerOutput(finished_sending={'req1'})
+    future1.set_result(output1)
+    future2.set_result(output2)
+
@@ -106,15 +133,15 @@ def test_async_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending == {'req1'}
     assert aggregated.finished_recving is None
+    assert aggregated.invalid_block_ids == {2}

     future1 = Future()
     future2 = Future()
     result_future = aggregator.async_aggregate([future1, future2])

-    output1 = DummyModelRunnerOutput(finished_sending=None,
-                                     finished_recving=None)
-    output2 = DummyModelRunnerOutput(finished_sending={'req1'},
-                                     finished_recving={'req2'})
+    output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
+    output2 = DummyModelRunnerOutput(finished_recving={'req2'},
+                                     invalid_block_ids={4, 5})
     future1.set_result(output1)
     future2.set_result(output2)

@@ -124,3 +151,4 @@ def test_async_aggregate_workers_output():
     aggregated = aggregated.kv_connector_output
     assert aggregated.finished_sending is None
     assert aggregated.finished_recving == {'req2'}
+    assert aggregated.invalid_block_ids == {3, 4, 5}
@@ -92,7 +92,7 @@ def test_basic_lifecycle():
     # (3b): execute_model()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
     model_runner_output.kv_connector_output = KVConnectorOutput(
-        finished_sending=[request_id])
+        finished_sending={request_id})

     # (3c): update_from_output()
     scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -139,7 +139,7 @@ def test_short_prompt_lifecycle():
     scheduler_output = scheduler.schedule()
     # Use create_model_runner_output to pass kv_connector_output along
     model_runner_output = create_model_runner_output(
-        reqs=[request], finished_sending=[request.request_id])
+        reqs=[request], finished_sending={request.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert_scheduler_empty(scheduler)

@@ -195,6 +195,6 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
     model_runner_output.kv_connector_output = KVConnectorOutput(
-        finished_sending=[request_remote.request_id])
+        finished_sending={request_remote.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert_scheduler_empty(scheduler)
@@ -78,7 +78,7 @@ def test_basic_lifecycle():
     # (2b): forward(): request finishes recv.
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
     model_runner_output.kv_connector_output = KVConnectorOutput(
-        finished_recving=[request_id])
+        finished_recving={request_id})

     # (2c): update_from_output():
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -197,7 +197,7 @@ def test_interleaved_lifecycle():

     model_runner_output = create_model_runner_output(
         [request_local_a, request_local_b],
-        finished_recving=[request_remote.request_id])
+        finished_recving={request_remote.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)

     # STEP 5: RECVed KVs are sent to ModelRunner.
@@ -246,16 +246,16 @@ def test_no_spurious_prefix_caching():
         request_id=1,
         block_size=BLOCK_SIZE,
         num_tokens=NUM_TOKENS,
+        common_prefix_len=NUM_TOKENS,
         do_remote_prefill=True,
-        use_all_1s_for_prompt_tokens=True,
     )

     request_local = create_request(
         request_id=2,
         block_size=BLOCK_SIZE,
         num_tokens=NUM_TOKENS,
+        common_prefix_len=NUM_TOKENS,
         do_remote_prefill=False,
-        use_all_1s_for_prompt_tokens=True,
     )

     # Schedule the remote prefill request. This should not
@@ -322,7 +322,7 @@ def test_full_block_prompt():
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
     model_runner_output.kv_connector_output = KVConnectorOutput(
-        finished_recving=[request_id])
+        finished_recving={request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.waiting) == 1
     assert (request_id in scheduler.finished_recving_kv_req_ids)
@@ -402,7 +402,7 @@ def test_cannot_schedule_after_recv():
     # Step 3: finish recving (5 blocks in use)
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output(
-        reqs=[request_normal], finished_recving=[request_remote.request_id])
+        reqs=[request_normal], finished_recving={request_remote.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.running) == 1
     assert len(scheduler.waiting) == 1
@@ -516,7 +516,7 @@ def test_cannot_recv():
     # Step 5: finish recving (5 blocks in use)
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output(
-        reqs=[], finished_recving=[request_remote.request_id])
+        reqs=[], finished_recving={request_remote.request_id})
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.running) == 0
     assert len(scheduler.waiting) == 1
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import tempfile
 from collections import defaultdict
+from itertools import count
 from typing import Any, Callable, Optional

 import torch
@@ -61,12 +62,15 @@ def create_vllm_config(
         max_num_seqs: int = 16,
         max_num_batched_tokens: int = 64,
         block_size: int = 16,
+        max_model_len: int = 10000,
+        enable_chunked_prefill: bool = True,
 ) -> VllmConfig:
     """Initialize VllmConfig For Testing."""
     scheduler_config = SchedulerConfig(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_num_batched_tokens,
+        max_model_len=max_model_len,
+        enable_chunked_prefill=enable_chunked_prefill,
     )
     model_config = ModelConfig(
         model=model,
@@ -117,19 +121,27 @@ def create_scheduler(
     )


+_request_count = count(1)
 _none_hash_initialized = False


-def create_request(request_id: int,
-                   num_tokens: int = 10,
-                   max_tokens: int = 16,
-                   do_remote_decode: bool = False,
-                   do_remote_prefill: bool = False,
-                   use_all_1s_for_prompt_tokens: bool = False,
-                   num_remote_blocks: int = 3,
-                   block_size: int = 16,
-                   hash_fn: Callable = sha256) -> Request:
+def create_request(
+    request_id: Optional[int] = None,
+    num_tokens: int = 10,
+    common_prefix_len=0,
+    max_tokens: int = 16,
+    do_remote_decode: bool = False,
+    do_remote_prefill: bool = False,
+    num_remote_blocks: int = 3,
+    block_size: int = 16,
+    hash_fn: Callable = sha256,
+) -> Request:
     """Make dummy request for testing."""
+    assert num_tokens >= common_prefix_len >= 0
+
+    if request_id is None:
+        request_id = next(_request_count)
+
     global _none_hash_initialized
     if not _none_hash_initialized:
         init_none_hash(hash_fn)
@@ -153,10 +165,9 @@ def create_request(request_id: int,
     max_tokens = 1 if do_remote_decode else max_tokens
     sampling_params = SamplingParams(max_tokens=max_tokens)

-    if use_all_1s_for_prompt_tokens:
-        prompt_token_ids = [1] * num_tokens
-    else:
-        prompt_token_ids = [i * request_id for i in range(num_tokens)]
+    common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else []
+    suffix = [i * request_id for i in range(num_tokens - common_prefix_len)]
+    prompt_token_ids = common_prefix + suffix

     req = Request(
         request_id=f"id-{request_id}",
@@ -173,8 +184,9 @@ def create_request(request_id: int,

 def create_model_runner_output(
     reqs: list[Request],
-    finished_sending: Optional[list[str]] = None,
-    finished_recving: Optional[list[str]] = None,
+    finished_sending: Optional[set[str]] = None,
+    finished_recving: Optional[set[str]] = None,
+    invalid_block_ids: Optional[set[int]] = None,
     use_eos: bool = False,
     token_id: int = 0,
 ) -> ModelRunnerOutput:
@@ -189,10 +201,11 @@ def create_model_runner_output(
     sampled_token_ids = [[sampled_token] for _ in req_ids]

     kv_connector_output = None if (
-        finished_sending is None
-        and finished_recving is None) else KVConnectorOutput(
+        finished_sending is None and finished_recving is None
+        and invalid_block_ids is None) else KVConnectorOutput(
             finished_sending=finished_sending,
             finished_recving=finished_recving,
+            invalid_block_ids=invalid_block_ids or set(),
         )

     # Make output data structure.
@@ -250,6 +250,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
         new_token_ids=[[]],
         new_block_ids=([[0]], ),
         num_computed_tokens=[0],
+        num_output_tokens=[0],
     )

     scheduler_output = SchedulerOutput(
@@ -117,7 +117,7 @@ def get_kv_connector_cache_layout():


 class KVOutputAggregator:
-    """Utility class to aggregate the output of all workers into a single 
+    """Utility class to aggregate the output of all workers into a single
     output corresponding to Rank 0 for scheduler."""

     def __init__(self, world_size: int):
@@ -143,6 +143,7 @@ class KVOutputAggregator:
         finished_sending = set[str]()
         finished_recving = set[str]()
         aggregated_kv_connector_stats = None
+        invalid_block_ids = set[int]()
         for model_runner_output in outputs:
             output = model_runner_output.kv_connector_output
             if not output:
@@ -165,6 +166,8 @@ class KVOutputAggregator:
                 aggregated_kv_connector_stats = \
                     aggregated_kv_connector_stats.aggregate(kv_connector_stats)

+            invalid_block_ids |= output.invalid_block_ids
+
         # select output of the worker specified by output_rank
         output = outputs[output_rank]

@@ -172,6 +175,7 @@ class KVOutputAggregator:
             finished_sending=finished_sending or None,
             finished_recving=finished_recving or None,
             kv_connector_stats=aggregated_kv_connector_stats or None,
+            invalid_block_ids=invalid_block_ids,
         )

         return output
@@ -229,6 +229,26 @@ class KVConnectorBase_V1(ABC):
         """
         return None, None

+    def get_block_ids_with_load_errors(self) -> set[int]:
+        """
+        Get the set of block IDs that failed to load.
+
+        Returns:
+            Set of block IDs that encountered load errors.
+            Empty set if no load errors occurred.
+
+        Notes:
+            - Applies to both sync- and async-loading requests.
+            - Async loading: failed blocks may be reported in any forward pass
+              up to and including the pass where the request ID is returned by
+              `get_finished()`. Even if failures occur, the request must still
+              be reported via `get_finished()`, and the failed block IDs must
+              appear here no later than that same pass.
+            - Sync loading: failed blocks should be reported in the forward
+              pass in which they are detected.
+        """
+        return set()
+
     def shutdown(self):
         """
         Shutdown the connector. This is called when the worker process
@@ -264,14 +284,21 @@ class KVConnectorBase_V1(ABC):

         Returns:
             A tuple with the following elements:
-                - An optional number of tokens that can be loaded from the
-                  external KV cache beyond what is already computed.
+                - An optional number of tokens that can be loaded from the
+                  external KV cache beyond what is already computed.
+                  If None, it means that the connector needs more time to
+                  determine the number of matched tokens, and the scheduler
+                  should query for this request again later.
                 - `True` if external KV cache tokens will be loaded
                   asynchronously (between scheduler steps). Must be
                   'False' if the first element is 0.
+
+        Notes:
+            The connector should only consider the largest prefix of prompt-
+            tokens for which KV cache is actually available at the time of the
+            call. If the cache cannot be loaded for some tokens (e.g., due to
+            connectivity issues or eviction), those tokens must not be taken
+            into account.
         """
         pass

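
The new `get_block_ids_with_load_errors()` hook is the worker-side half of the recovery contract: whatever fails to load during a forward pass must be surfaced there, no later than the pass in which the request is reported finished. A minimal standalone sketch of one way to satisfy that (report-and-clear bookkeeping; `LoadErrorTracker` is hypothetical and not part of this change):

```python
class LoadErrorTracker:
    """Collects failed KV block IDs during a forward pass and hands them
    to the scheduler exactly once."""

    def __init__(self) -> None:
        self._load_errors: set[int] = set()

    def record_failure(self, block_id: int) -> None:
        # Called from wherever a block load fails (sync or async path).
        self._load_errors.add(block_id)

    def get_block_ids_with_load_errors(self) -> set[int]:
        # Report-and-clear, so each failed block is rescheduled only once.
        errors, self._load_errors = self._load_errors, set()
        return errors


tracker = LoadErrorTracker()
tracker.record_failure(7)
tracker.record_failure(42)
assert tracker.get_block_ids_with_load_errors() == {7, 42}
assert tracker.get_block_ids_with_load_errors() == set()
```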
@@ -189,6 +189,12 @@ class MultiConnector(KVConnectorBase_V1):

         return finished_sending or None, finished_recving or None

+    def get_block_ids_with_load_errors(self) -> set[int]:
+        agg_block_ids: set[int] = set()
+        for c in self._connectors:
+            agg_block_ids |= c.get_block_ids_with_load_errors()
+        return agg_block_ids
+
     # ==============================
     # Scheduler-side methods
     # ==============================
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import hashlib
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Optional

 import safetensors
@@ -55,10 +55,7 @@ class ReqMeta:

 @dataclass
 class SharedStorageConnectorMetadata(KVConnectorMetadata):
-    requests: list[ReqMeta]
-
-    def __init__(self):
-        self.requests = []
+    requests: list[ReqMeta] = field(default_factory=list)

     def add_request(
         self,
@@ -211,7 +211,7 @@ class BlockPool:
             block_size: Number of tokens in each block.
             kv_cache_group_id: The id of the KV cache group.
         """
-        if num_cached_blocks == num_full_blocks:
+        if num_cached_blocks >= num_full_blocks:
             return
         new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
         assert len(request.block_hashes) >= num_full_blocks
@@ -101,6 +101,7 @@ class CachedRequestData:
     new_token_ids: list[list[int]]
     new_block_ids: list[Optional[tuple[list[int], ...]]]
     num_computed_tokens: list[int]
+    num_output_tokens: list[int]

     @property
     def num_reqs(self) -> int:
@@ -114,6 +115,7 @@ class CachedRequestData:
             new_token_ids=[],
             new_block_ids=[],
             num_computed_tokens=[],
+            num_output_tokens=[],
         )

@@ -133,6 +133,7 @@ class Scheduler(SchedulerInterface):

         # KV Connector: requests in process of async KV loading or recving
         self.finished_recving_kv_req_ids: set[str] = set()
+        self.failed_recving_kv_req_ids: set[str] = set()

         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -671,6 +672,7 @@ class Scheduler(SchedulerInterface):
         new_token_ids: list[list[int]] = []
         new_block_ids: list[Optional[tuple[list[int], ...]]] = []
         num_computed_tokens: list[int] = []
+        num_output_tokens: list[int] = []

         use_connector = self.connector is not None
         for req in itertools.chain(running_reqs, resumed_reqs):
@@ -695,6 +697,7 @@ class Scheduler(SchedulerInterface):
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True))
             num_computed_tokens.append(req.num_computed_tokens)
+            num_output_tokens.append(len(req.output_token_ids))
         # Because resumed_reqs is usually empty, it is more efficient to do
         # in-place appending so that we don't need to allocate a new list.
         resumed_from_preemption = [False] * len(running_reqs)
@@ -706,6 +709,7 @@ class Scheduler(SchedulerInterface):
             new_token_ids=new_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
+            num_output_tokens=num_output_tokens,
         )

     def _try_schedule_encoder_inputs(
@@ -878,6 +882,14 @@ class Scheduler(SchedulerInterface):
         kv_connector_stats = (kv_connector_output.kv_connector_stats
                               if kv_connector_output else None)

+        failed_kv_load_req_ids = None
+        if kv_connector_output and kv_connector_output.invalid_block_ids:
+            # These blocks contain externally computed tokens that failed to
+            # load. Identify affected requests and adjust their computed token
+            # count to trigger recomputation of the invalid blocks.
+            failed_kv_load_req_ids = self._handle_invalid_blocks(
+                kv_connector_output.invalid_block_ids)
+
         # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
         # the below loop can be a performance bottleneck. We should do our best
         # to avoid expensive operations inside the loop.
@@ -885,6 +897,9 @@ class Scheduler(SchedulerInterface):
         stopped_preempted_reqs: set[Request] = set()
         for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
             assert num_tokens_scheduled > 0
+            if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
+                # Skip requests that were recovered from KV load failure
+                continue
             request = self.requests.get(req_id)
             if request is None:
                 # The request is already finished. This can happen if the
@@ -988,9 +1003,8 @@ class Scheduler(SchedulerInterface):
         self.waiting.remove_requests(stopped_preempted_reqs)

         # KV Connector: update state for finished KV Transfers.
-        if model_runner_output.kv_connector_output:
-            self._update_from_kv_xfer_finished(
-                model_runner_output.kv_connector_output)
+        if kv_connector_output:
+            self._update_from_kv_xfer_finished(kv_connector_output)

         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
@@ -1252,18 +1266,33 @@ class Scheduler(SchedulerInterface):
         if request.request_id not in self.finished_recving_kv_req_ids:
             return False

-        # Now that the blocks are ready, actually cache them.
-        (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
-        num_computed_tokens = len(block_ids) * self.block_size
-        # Handle the case where num request tokens less than one block.
-        num_computed_tokens = min(num_computed_tokens, request.num_tokens)
-        if num_computed_tokens == request.num_tokens:
-            num_computed_tokens -= 1
-        # This will cache the blocks iff caching is enabled.
-        self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+        if request.request_id in self.failed_recving_kv_req_ids:
+            # Request had KV load failures; num_computed_tokens was already
+            # updated in _update_requests_with_invalid_blocks
+            if request.num_computed_tokens:
+                # Cache any valid computed tokens.
+                self.kv_cache_manager.cache_blocks(request,
+                                                   request.num_computed_tokens)
+            else:
+                # No valid computed tokens, release allocated blocks.
+                # There may be a local cache hit on retry.
+                self.kv_cache_manager.free(request)

-        # Update the request state for scheduling.
-        request.num_computed_tokens = num_computed_tokens
+            self.failed_recving_kv_req_ids.remove(request.request_id)
+        else:
+            # Now that the blocks are ready, actually cache them.
+            (block_ids, ) = self.kv_cache_manager.get_block_ids(
+                request.request_id)
+            num_computed_tokens = len(block_ids) * self.block_size
+            # Handle the case where num request tokens less than one block.
+            num_computed_tokens = min(num_computed_tokens, request.num_tokens)
+            if num_computed_tokens == request.num_tokens:
+                num_computed_tokens -= 1
+            # This will cache the blocks iff caching is enabled.
+            self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
+
+            # Update the request state for scheduling.
+            request.num_computed_tokens = num_computed_tokens

         # Return that we are ready.
         self.finished_recving_kv_req_ids.remove(request.request_id)
@@ -1296,3 +1325,134 @@ class Scheduler(SchedulerInterface):
                     "but the request is already freed.", req_id)
         else:
             self._free_blocks(self.requests[req_id])

+    def _update_requests_with_invalid_blocks(
+            self, requests: Iterable[Request],
+            invalid_block_ids: set[int]) -> tuple[set[str], int]:
+        """
+        Identify and update requests affected by invalid KV cache blocks.
+
+        This method scans the given requests, detects those with invalid blocks
+        and adjusts their `num_computed_tokens` to the longest valid prefix.
+        For observability, it also accumulates the total number of tokens that
+        will need to be recomputed across all affected requests.
+
+        Args:
+            requests: The set of requests to scan for invalid blocks.
+            invalid_block_ids: IDs of invalid blocks.
+
+        Returns:
+            tuple:
+                - affected_req_ids (set[str]): IDs of requests impacted by
+                  invalid blocks.
+                - total_affected_tokens (int): Total number of tokens that must
+                  be recomputed across all affected requests (for
+                  observability).
+        """
+        affected_req_ids: set[str] = set()
+        total_affected_tokens = 0
+        # If a block is invalid and shared by multiple requests in the batch,
+        # these requests must be rescheduled, but only the first will recompute
+        # it. This set tracks blocks already marked for recomputation.
+        marked_invalid_block_ids: set[int] = set()
+        for request in requests:
+            is_affected = False
+            marked_invalid_block = False
+            req_id = request.request_id
+            # TODO (davidb): add support for hybrid memory allocator
+            (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id)
+            # We iterate only over blocks that may contain externally computed
+            # tokens
+            if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
+                # Async loading. If num_computed_tokens is set it implies we
+                # already processed some block failures for it in a prior step
+                req_num_computed_tokens = (
+                    request.num_computed_tokens if req_id
+                    in self.failed_recving_kv_req_ids else len(req_block_ids) *
+                    self.block_size)
+            else:
+                # Sync loading. num_computed_tokens includes new tokens
+                req_num_computed_tokens = request.num_cached_tokens

+            req_num_computed_blocks = (req_num_computed_tokens +
+                                       self.block_size - 1) // self.block_size
+            for idx, block_id in zip(range(req_num_computed_blocks),
+                                     req_block_ids):

+                if block_id not in invalid_block_ids:
+                    continue

+                is_affected = True

+                if block_id in marked_invalid_block_ids:
+                    # This invalid block is shared with a previous request
+                    # and was already marked for recomputation.
+                    # This means this request can still consider this block
+                    # as computed when rescheduled.
+                    # Currently this only applies to sync loading; Async
+                    # loading does not yet support block sharing
+                    continue

+                marked_invalid_block_ids.add(block_id)

+                if marked_invalid_block:
+                    # This request has already marked an invalid block for
+                    # recomputation and updated its num_computed_tokens.
+                    continue

+                marked_invalid_block = True
+                # Truncate the computed tokens at the first failed block
+                request.num_computed_tokens = idx * self.block_size
+                total_affected_tokens += (req_num_computed_tokens -
+                                          request.num_computed_tokens)

+            if is_affected:
+                if not marked_invalid_block:
+                    # All invalid blocks of this request are shared with
+                    # previous requests and will be recomputed by them.
+                    # Revert to considering only cached tokens as computed.
+                    # Currently this only applies to sync loading; Async
+                    # loading does not yet support block sharing
+                    total_affected_tokens += (request.num_computed_tokens -
+                                              request.num_cached_tokens)
+                    request.num_computed_tokens = request.num_cached_tokens

+                affected_req_ids.add(request.request_id)

+        return (affected_req_ids, total_affected_tokens)

+    def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
+        total_requests_to_reschedule = 0
+        total_tokens_to_reschedule = 0

+        # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) ---
+        async_load_reqs = (
+            req for req in self.waiting
+            if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
+        async_affected_req_ids, num_tokens_to_reschedule = (
+            self._update_requests_with_invalid_blocks(async_load_reqs,
+                                                      invalid_block_ids))

+        total_requests_to_reschedule += len(async_affected_req_ids)
+        total_tokens_to_reschedule += num_tokens_to_reschedule

+        # Mark requests with async KV load failures; they will be rescheduled
+        # once loading completes
+        self.failed_recving_kv_req_ids |= async_affected_req_ids

+        # --- Handle sync KV loads (running requests) ---
+        sync_affected_req_ids, num_tokens_to_reschedule = (
+            self._update_requests_with_invalid_blocks(self.running,
+                                                      invalid_block_ids))

+        total_requests_to_reschedule += len(sync_affected_req_ids)
+        total_tokens_to_reschedule += num_tokens_to_reschedule

+        if total_requests_to_reschedule:
+            logger.warning(
+                "Recovered from KV load failure: "
+                "%d request(s) rescheduled (%d tokens affected).",
+                total_requests_to_reschedule, total_tokens_to_reschedule)

+        # Return the IDs of affected running requests to skip in
+        # update_from_output.
+        return sync_affected_req_ids
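
To make the truncation rule in `_update_requests_with_invalid_blocks` concrete: with a hypothetical block size of 16 tokens, a request that believed 99 blocks were externally computed, and a load failure first detected in block index 3, the computed-token count is cut back to the start of that block and everything after it is rescheduled:

```python
block_size = 16
req_num_computed_tokens = 99 * block_size  # 1584 tokens assumed loaded
first_invalid_block_idx = 3                # earliest block that failed to load

# Truncate to the longest valid prefix, as the scheduler does above.
num_computed_tokens = first_invalid_block_idx * block_size         # 48
tokens_to_reschedule = req_num_computed_tokens - num_computed_tokens
assert (num_computed_tokens, tokens_to_reschedule) == (48, 1536)
```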
@@ -142,6 +142,9 @@ class SingleTypeKVCacheManager(ABC):
         num_cached_blocks = self.num_cached_block[request.request_id]
         num_full_blocks = num_tokens // self.block_size

+        if num_cached_blocks >= num_full_blocks:
+            return
+
         self.block_pool.cache_full_blocks(
             request=request,
             blocks=self.req_to_blocks[request.request_id],
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, NamedTuple, Optional, Union

 import torch
@@ -87,10 +87,13 @@ class KVConnectorOutput:
     finished_sending: Optional[set[str]] = None
     finished_recving: Optional[set[str]] = None
     kv_connector_stats: Optional["KVConnectorStats"] = None
+    # IDs of externally computed KV blocks that failed to load.
+    # Requests referencing these blocks should be rescheduled to recompute them.
+    invalid_block_ids: set[int] = field(default_factory=set)

     def is_empty(self):
         return (not self.finished_sending and not self.finished_recving
-                and not self.kv_connector_stats)
+                and not self.kv_connector_stats and not self.invalid_block_ids)


 # ModelRunnerOutput is serialized and sent to the scheduler process.
@@ -634,8 +634,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
+            num_output_tokens = req_data.num_output_tokens[i]

             # Update the cached states.

             req_state.num_computed_tokens = num_computed_tokens

             if not is_last_rank:
@@ -653,6 +655,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             elif num_new_tokens > 0:
                 req_state.output_token_ids.extend(
                     new_token_ids[-num_new_tokens:])
+            elif num_output_tokens < len(req_state.output_token_ids):
+                # Some output tokens were discarded due to a sync-KV-load
+                # failure. Align the cached state.
+                del req_state.output_token_ids[num_output_tokens:]
+
+                req_index = self.input_batch.req_id_to_index.get(req_id)
+                if req_index is not None:
+                    old_end_idx = self.input_batch.num_tokens_no_spec[
+                        req_index]
+                    end_idx = self.input_batch.num_prompt_tokens[
+                        req_index] + num_output_tokens
+                    self.input_batch.num_tokens[req_index] = end_idx
+                    self.input_batch.num_tokens_no_spec[req_index] = end_idx
+                    self.input_batch.is_token_ids[req_index,
+                                                  end_idx:old_end_idx] = False

             # Update the block IDs.
             if not resumed_from_preemption:
@@ -464,8 +464,7 @@ class Worker(WorkerBase):

         # In case of PP with kv transfer, we need to pass through the
         # kv_connector_output
-        if (not kv_connector_output.finished_sending
-                and not kv_connector_output.finished_recving):
+        if kv_connector_output.is_empty():
             return EMPTY_MODEL_RUNNER_OUTPUT

         output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
@@ -75,8 +75,7 @@ class KVConnectorModelRunnerMixin:
                 scheduler_output, wait_for_save=False) as kv_connector_output:
             pass

-        if (not kv_connector_output.finished_sending
-                and not kv_connector_output.finished_recving):
+        if kv_connector_output.is_empty():
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
@@ -120,6 +119,8 @@ class KVConnectorModelRunnerMixin:

         output.finished_sending, output.finished_recving = (
             kv_connector.get_finished(scheduler_output.finished_req_ids))
+        output.invalid_block_ids = (
+            kv_connector.get_block_ids_with_load_errors())

         output.kv_connector_stats = KVConnectorModelRunnerMixin.\
             get_kv_connector_stats()