Revert "[V1] Exception Handling when Loading KV Cache from Remote Store" (#21778)

Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Kuntai Du
2025-07-28 13:15:18 -07:00
committed by GitHub
parent 9ba1c88a93
commit b18b417fbf
10 changed files with 5 additions and 229 deletions

View File

@@ -1,120 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import random
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


@dataclass
class RandomDropConnectorMetadata(KVConnectorMetadata):
    req_meta: dict[str, list[int]]


class RandomDropConnector(KVConnectorBase_V1):
    """
    A connector designed for fault-tolerance testing: it randomly drops
    KV data while loading or receiving the KV cache.

    This class simulates real-world scenarios in which requests or data
    may be lost or time out, allowing developers to test and validate the
    system's ability to handle such failures.

    Attributes:
        finished_recving_kv_req_ids (set[str]): A set of request IDs that
            have completed receiving KV cache data.
        finished_loading_dict (dict[str, int]): A dictionary that tracks
            the actual number of tokens loaded from the remote KV store
            for each completed request. The keys are request IDs, and
            the values are the corresponding token counts.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)

        self.failure_request: list[str] = []
        self._reqs_need_recv: dict[str, list[int]] = {}
        self._finish_load: dict[str, int] = {}
        self.chunk_size = 256

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        if request.request_id in self.failure_request:
            self.failure_request.remove(request.request_id)
            return 0, False
        num_external_hit_tokens = request.num_prompt_tokens - 1
        logger.info(
            "request %s num_prompt_tokens %d num_external_hit_tokens %d",
            request.request_id, request.num_prompt_tokens,
            num_external_hit_tokens)
        return num_external_hit_tokens, True

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        if num_external_tokens > 0:
            self._reqs_need_recv[
                request.
                request_id] = request.prompt_token_ids[:num_external_tokens]

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        req_meta = self._reqs_need_recv.copy()
        self._reqs_need_recv.clear()
        return RandomDropConnectorMetadata(req_meta)

    def add_failure_request(self, request: "Request"):
        self.failure_request.append(request.request_id)

    def start_load_kv(self, forward_context, **kwargs) -> None:
        for request_id, hit_tokens in self._get_connector_metadata(
        ).req_meta.items():
            num_actual_load_tokens = self.load_kv(request_id, hit_tokens)
            logger.info("request %s hit_tokens %d num_actual_load_tokens %d",
                        request_id, len(hit_tokens), num_actual_load_tokens)
            self._finish_load[request_id] = num_actual_load_tokens

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        pass

    def wait_for_save(self):
        pass

    def load_kv(self, request_id, hit_tokens):
        num_actual_load_tokens = random.randint(0, len(hit_tokens))
        return num_actual_load_tokens

    def get_finished_loading(self) -> dict[str, int]:
        if not self._finish_load:
            return {}
        finished_loading = self._finish_load.copy()
        self._finish_load.clear()
        return finished_loading
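
The connector above follows a record-then-drain handshake: start_load_kv() records how many tokens actually arrived for each request, and get_finished_loading() reports each result exactly once. A minimal self-contained sketch of that pattern (standalone toy code, not the vLLM API):

import random

class DropSimulator:
    """Toy stand-in for the connector's load bookkeeping."""

    def __init__(self) -> None:
        self._finish_load: dict[str, int] = {}

    def start_load(self, request_id: str, hit_tokens: list[int]) -> None:
        # Simulate a partial load: anywhere from 0 to all hit tokens arrive.
        self._finish_load[request_id] = random.randint(0, len(hit_tokens))

    def get_finished_loading(self) -> dict[str, int]:
        # Drain-once semantics: each result is reported a single time.
        done, self._finish_load = self._finish_load, {}
        return done

sim = DropSimulator()
sim.start_load("req-0", list(range(256)))
print(sim.get_finished_loading())  # e.g. {'req-0': 173}
print(sim.get_finished_loading())  # {} (already drained)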

View File

@@ -1,16 +0,0 @@
#!/bin/bash
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
export PYTHONPATH=$PYTHONPATH:$SCRIPT_DIR
vllm serve DeepSeek-V2-Lite-Chat \
--trust-remote-code \
--served-model-name vllm_cpu_offload \
--max-model-len 32768 \
--no-enable-prefix-caching \
--max-seq-len-to-capture 10000 \
--max-num-seqs 64 \
--gpu-memory-utilization 0.9 \
--host 0.0.0.0 \
-tp 2 \
--kv-transfer-config '{"kv_connector":"RandomDropConnector","kv_role":"kv_both","kv_connector_module_path":"random_drop_connector"}'
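
Once the server is up, it speaks the usual OpenAI-compatible HTTP API, so the fault-injection path can be exercised with any client. A minimal sketch, assuming the default port 8000 and the served model name from the script above:

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "vllm_cpu_offload",  # matches --served-model-name above
        "prompt": "The capital of France is",
        "max_tokens": 16,
    },
)
print(resp.json()["choices"][0]["text"])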

View File

@@ -139,27 +139,13 @@ class KVOutputAggregator:
                finished_set.add(req_id)
                del remaining_count_dict[req_id]

        def update_finished_load_dict(worker_finished_loading_dict: dict[str,
                                                                         int],
                                      finished_loading_dict: dict[str, int]):
            for req_id, num_actual_load_tokens in (worker_finished_loading_dict
                                                   or {}).items():
                if req_id in finished_loading_dict:
                    finished_loading_dict[req_id] = min(
                        finished_loading_dict[req_id], num_actual_load_tokens)
                else:
                    finished_loading_dict[req_id] = num_actual_load_tokens

        finished_sending = set[str]()
        finished_recving = set[str]()
        finished_loading_dict: dict[str, int] = {}
        for output in outputs:
            update_finished_set(output.finished_sending,
                                self._send_remaining_count, finished_sending)
            update_finished_set(output.finished_recving,
                                self._recv_remaining_count, finished_recving)
            update_finished_load_dict(output.finished_loading_dict,
                                      finished_loading_dict)

        # select output of the worker specified by output_rank
        output = outputs[output_rank]

@@ -171,7 +157,7 @@ class KVOutputAggregator:
        # send/recv
        output.finished_sending = finished_sending if finished_sending else None
        output.finished_recving = finished_recving if finished_recving else None
        output.finished_loading_dict = finished_loading_dict or None

        return output

    def async_aggregate(self,
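
The min() in update_finished_load_dict is the conservative choice: a request's remote KV is only usable up to the smallest token count that any worker reported, so the aggregate keeps the per-request minimum. A standalone sketch of that merge (hypothetical data, not the vLLM API):

def aggregate_loading(per_worker: list[dict[str, int]]) -> dict[str, int]:
    merged: dict[str, int] = {}
    for worker_dict in per_worker:
        for req_id, n in (worker_dict or {}).items():
            # Only tokens that *every* worker managed to load count.
            merged[req_id] = min(merged.get(req_id, n), n)
    return merged

print(aggregate_loading([{"req-0": 200}, {"req-0": 120}]))  # {'req-0': 120}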

View File

@@ -28,9 +28,6 @@ The class provides the following primitives:
    get_finished() - called with ids of finished requests, returns
        ids of requests that have completed async sending/recving.
    get_finished_loading() - called with scheduler outputs, returns
        a dictionary whose keys are request IDs and whose values are
        the actual number of tokens loaded from the remote KV cache.
"""

import enum

@@ -222,23 +219,6 @@ class KVConnectorBase_V1(ABC):
        """
        return None, None

    def get_finished_loading(
            self, scheduler_output: "SchedulerOutput") -> dict[str, int]:
        """
        Retrieves the actual number of tokens loaded for requests that have
        completed the asynchronous loading process from the remote KV cache.

        This method is used by the scheduler process (via the executors)
        to track the progress of requests and determine which requests have
        successfully finished loading their KV cache data.

        Returns:
            A dictionary where the keys are request IDs and the values are
            the number of tokens that have been successfully loaded for
            each request.
        """
        return {}

    # ==============================
    # Scheduler-side methods
    # ==============================

View File

@@ -1167,8 +1167,6 @@ class IntermediateTensors:
    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

    # req_id -> num_actual_load_tokens
    finished_loading_dict: Optional[dict[str, int]] = None

    def __init__(self, tensors):
        # manually define this function, so that

View File

@@ -118,9 +118,6 @@ class Scheduler(SchedulerInterface):
        # KV Connector: requests in process of async KV loading or recving
        self.finished_recving_kv_req_ids: set[str] = set()

        # The keys are request IDs, and the values are the corresponding
        # token counts that have been successfully loaded from the remote
        # KV store.
        self.finished_loading_dict: dict[str, int] = {}

        # Encoder-related.
        # Calculate encoder cache size if applicable

@@ -1097,27 +1094,6 @@ class Scheduler(SchedulerInterface):
        (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
        return self.connector.request_finished(request, block_ids)

    def _update_actual_load_token_num_from_remote_kv(self,
                                                     request: Request) -> bool:
        num_actual_load_tokens = self.finished_loading_dict.pop(
            request.request_id)
        num_computed_tokens = num_actual_load_tokens
        assert self.connector is not None
        if num_actual_load_tokens <= 0 and hasattr(self.connector,
                                                   "add_failure_request"):
            self.connector.add_failure_request(request)
            return True
        if num_actual_load_tokens == request.num_tokens:
            num_computed_tokens -= 1
        self.kv_cache_manager.cache_blocks(request, num_computed_tokens)

        # Update the request state for scheduling.
        request.num_computed_tokens = num_computed_tokens
        return True

    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
        """
        KV Connector: check if the request_id is finished_recving.

@@ -1131,9 +1107,6 @@ class Scheduler(SchedulerInterface):
            WAITING_FOR_REMOTE_KV.
        """
        assert self.connector is not None

        if request.request_id in self.finished_loading_dict:
            return self._update_actual_load_token_num_from_remote_kv(request)

        if request.request_id not in self.finished_recving_kv_req_ids:
            return False

@@ -1172,6 +1145,3 @@ class Scheduler(SchedulerInterface):
        for req_id in (model_runner_output.finished_sending or ()):
            logger.debug("Finished sending KV transfer for request %s", req_id)
            self._free_blocks(self.requests[req_id])

        if model_runner_output.finished_loading_dict:
            self.finished_loading_dict.update(
                model_runner_output.finished_loading_dict)
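
The removed scheduler hook distinguishes three outcomes: nothing loaded (flag the request as a failure so the connector can skip the remote store on retry), a full hit (back off by one token so the model still runs a forward pass over the last token and produces logits, mirroring the num_prompt_tokens - 1 returned by get_num_new_matched_tokens), and a partial hit (resume from the loaded prefix). A standalone sketch of that decision, under those assumptions:

def resolve_load(num_actual_load_tokens: int, num_tokens: int):
    if num_actual_load_tokens <= 0:
        return "retry_without_remote_kv"  # scheduler marks the request failed
    num_computed = num_actual_load_tokens
    if num_actual_load_tokens == num_tokens:
        # Full hit: leave one token uncomputed so a forward pass still
        # runs and yields logits for sampling.
        num_computed -= 1
    return num_computed

print(resolve_load(0, 100))    # retry_without_remote_kv
print(resolve_load(100, 100))  # 99
print(resolve_load(60, 100))   # 60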

View File

@@ -107,8 +107,6 @@ class ModelRunnerOutput:
    # [req_ids]
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

    # req_id -> actual_load_token from connector
    finished_loading_dict: Optional[dict[str, int]] = None

    # req_id -> num_nans_in_logits
    num_nans_in_logits: Optional[dict[str, int]] = None

@@ -123,5 +121,4 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                              pooler_output=[],
                                              finished_sending=None,
                                              finished_recving=None,
                                              finished_loading_dict=None,
                                              num_nans_in_logits=None)

View File

@@ -1375,7 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        num_scheduled_tokens_np: np.ndarray,
        finished_sending: Optional[set[str]],
        finished_recving: Optional[set[str]],
        finished_loading_dict: Optional[dict[str, int]],
    ) -> ModelRunnerOutput:
        assert self.input_batch.num_reqs ==\
            len(self.input_batch.pooling_params), \

@@ -1412,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            pooler_output=pooler_output,
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            finished_loading_dict=finished_loading_dict,
        )

    @torch.inference_mode()

@@ -1532,7 +1530,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            self.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))
            finished_loading_dict = self.get_finished_loading(scheduler_output)

        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output

@@ -1550,11 +1547,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            if not broadcast_pp_output:
                if (finished_sending or finished_recving
                        or finished_loading_dict):
                if finished_sending or finished_recving:
                    hidden_states.finished_sending = finished_sending
                    hidden_states.finished_recving = finished_recving
                    hidden_states.finished_loading_dict = finished_loading_dict
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            get_pp_group().send_tensor_dict(hidden_states.tensors,

@@ -1564,7 +1559,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        if self.input_batch.pooling_params:
            return self._pool(hidden_states, num_scheduled_tokens,
                              num_scheduled_tokens_np, finished_sending,
                              finished_recving, finished_loading_dict)
                              finished_recving)

        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

@@ -1716,7 +1711,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            finished_loading_dict=finished_loading_dict,
            num_nans_in_logits=num_nans_in_logits,
        )

View File

@@ -359,12 +359,10 @@ class Worker(WorkerBase):
            # In case of PP with kv transfer, we need to pass through the
            # finished_sending and finished_recving buffers.
            new_output = EMPTY_MODEL_RUNNER_OUTPUT
            if (output.finished_sending or output.finished_recving
                    or output.finished_loading_dict):
            if output.finished_sending or output.finished_recving:
                new_output = copy.copy(new_output)
                new_output.finished_sending = output.finished_sending
                new_output.finished_recving = output.finished_recving
                new_output.finished_loading_dict = output.finished_loading_dict
            output = new_output

        assert isinstance(output, ModelRunnerOutput)
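
The copy.copy() above is a copy-on-write guard: EMPTY_MODEL_RUNNER_OUTPUT is a shared module-level constant, so it must be cloned before any per-step fields are attached. A small sketch of why (simplified types, not the vLLM classes):

import copy
from dataclasses import dataclass
from typing import Optional

@dataclass
class Output:
    finished_sending: Optional[set[str]] = None

EMPTY = Output()  # shared singleton, must never be mutated

def passthrough(finished_sending: Optional[set[str]]) -> Output:
    out = EMPTY
    if finished_sending:
        out = copy.copy(out)  # clone before writing per-step state
        out.finished_sending = finished_sending
    return out

print(passthrough({"req-0"}).finished_sending)  # {'req-0'}
print(EMPTY.finished_sending)                   # None, singleton untouched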

View File

@@ -53,14 +53,6 @@ class KVConnectorModelRunnerMixin:
                scheduler_output.finished_req_ids)
        return None, None

    @staticmethod
    def get_finished_loading(
            scheduler_output: "SchedulerOutput", ) -> dict[str, int]:
        if has_kv_transfer_group():
            return get_kv_transfer_group().get_finished_loading(
                scheduler_output)
        return {}

    def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
                                vllm_config: VllmConfig) -> ModelRunnerOutput:
        # KV send/recv even if no work to do.

@@ -68,14 +60,11 @@ class KVConnectorModelRunnerMixin:
            self.maybe_setup_kv_connector(scheduler_output)
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))
            finished_loading_dict = self.get_finished_loading(scheduler_output)

        if (not finished_sending and not finished_recving
                and not finished_loading_dict):
        if not finished_sending and not finished_recving:
            return EMPTY_MODEL_RUNNER_OUTPUT

        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.finished_sending = finished_sending
        output.finished_recving = finished_recving
        output.finished_loading_dict = finished_loading_dict
        return output