Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
Revert "[V1] Exception Handling when Loading KV Cache from Remote Store" (#21778)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
@@ -1,120 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import logging
-import random
-from dataclasses import dataclass
-from typing import TYPE_CHECKING
-
-import torch
-
-from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer.kv_connector.v1.base import (
-    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
-from vllm.v1.core.sched.output import SchedulerOutput
-
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionMetadata
-    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
-    from vllm.v1.request import Request
-
-logger = logging.getLogger()
-logging.basicConfig(level=logging.INFO)
-
-
-@dataclass
-class RandomDropConnectorMetadata(KVConnectorMetadata):
-    req_meta: dict[str, list[int]]
-
-
-class RandomDropConnector(KVConnectorBase_V1):
-    """
-    A connector designed for fault-tolerance testing by randomly dropping
-    KV data while loading or receiving the KV cache.
-
-    This class simulates real-world scenarios in which requests or data
-    may be lost or time out, allowing developers to test and validate the
-    system's ability to handle such failures.
-
-    Attributes:
-        finished_recving_kv_req_ids (set[str]): A set of request IDs that
-            have finished receiving KV cache data.
-        finished_loading_dict (dict[str, int]): A dictionary that tracks
-            the actual number of tokens loaded from the remote KV store
-            for each completed request. The keys are request IDs, and
-            the values are the corresponding token counts.
-    """
-
-    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
-        super().__init__(vllm_config=vllm_config, role=role)
-
-        self.failure_request: list[str] = []
-        self._reqs_need_recv: dict[str, list[int]] = {}
-        self._finish_load: dict[str, int] = {}
-
-        self.chunk_size = 256
-
-    ############################################################
-    # Scheduler Side Methods
-    ############################################################
-
-    def get_num_new_matched_tokens(
-            self, request: "Request",
-            num_computed_tokens: int) -> tuple[int, bool]:
-        if request.request_id in self.failure_request:
-            self.failure_request.remove(request.request_id)
-            return 0, False
-        num_external_hit_tokens = request.num_prompt_tokens - 1
-        logger.info(
-            "request %s num_prompt_tokens %d num_external_hit_tokens %d",
-            request.request_id, request.num_prompt_tokens,
-            num_external_hit_tokens)
-        return num_external_hit_tokens, True
-
-    def update_state_after_alloc(self, request: "Request",
-                                 blocks: "KVCacheBlocks",
-                                 num_external_tokens: int):
-        if num_external_tokens > 0:
-            self._reqs_need_recv[
-                request.
-                request_id] = request.prompt_token_ids[:num_external_tokens]
-
-    def build_connector_meta(
-        self,
-        scheduler_output: SchedulerOutput,
-    ) -> KVConnectorMetadata:
-        req_meta = self._reqs_need_recv.copy()
-        self._reqs_need_recv.clear()
-        return RandomDropConnectorMetadata(req_meta)
-
-    def add_failure_request(self, request: "Request"):
-        self.failure_request.append(request.request_id)
-
-    def start_load_kv(self, forward_context, **kwargs) -> None:
-        for request_id, hit_tokens in self._get_connector_metadata(
-        ).req_meta.items():
-            num_actual_load_tokens = self.load_kv(request_id, hit_tokens)
-            logger.info("request %s hit_tokens %d num_actual_load_tokens %d",
-                        request_id, len(hit_tokens), num_actual_load_tokens)
-            self._finish_load[request_id] = num_actual_load_tokens
-
-    def wait_for_layer_load(self, layer_name: str) -> None:
-        pass
-
-    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
-                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
-        pass
-
-    def wait_for_save(self):
-        pass
-
-    def load_kv(self, request_id, hit_tokens):
-        num_actual_load_tokens = random.randint(0, len(hit_tokens))
-        return num_actual_load_tokens
-
-    def get_finished_loading(self) -> dict[str, int]:
-        if not self._finish_load:
-            return {}
-        finished_loading = self._finish_load.copy()
-        self._finish_load.clear()
-
-        return finished_loading
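For context on what is being removed: the connector's load-side contract can be exercised in isolation. The following is a minimal sketch, not part of the diff, that mimics load_kv and get_finished_loading with a hypothetical MockRandomDropLoader class (no vLLM imports) to show how a random partial load surfaces as a per-request token count.

    import random

    class MockRandomDropLoader:
        """Illustrative stand-in for RandomDropConnector's load path."""

        def __init__(self) -> None:
            self._finish_load: dict[str, int] = {}

        def load_kv(self, request_id: str, hit_tokens: list[int]) -> int:
            # Simulate a lossy remote store: anywhere from none to all of
            # the matched tokens may actually arrive.
            num_actual_load_tokens = random.randint(0, len(hit_tokens))
            self._finish_load[request_id] = num_actual_load_tokens
            return num_actual_load_tokens

        def get_finished_loading(self) -> dict[str, int]:
            # Drain-and-clear semantics: each result is reported exactly once.
            finished, self._finish_load = self._finish_load, {}
            return finished

    loader = MockRandomDropLoader()
    loader.load_kv("req-0", list(range(100)))
    print(loader.get_finished_loading())  # e.g. {'req-0': 42}
    print(loader.get_finished_loading())  # {} -- already drained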
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
-export PYTHONPATH=$PYTHONPATH:$SCRIPT_DIR
-
-vllm serve DeepSeek-V2-Lite-Chat \
-    --trust-remote-code \
-    --served-model-name vllm_cpu_offload \
-    --max-model-len 32768 \
-    --no-enable-prefix-caching \
-    --max-seq-len-to-capture 10000 \
-    --max-num-seqs 64 \
-    --gpu-memory-utilization 0.9 \
-    --host 0.0.0.0 \
-    -tp 2 \
-    --kv-transfer-config '{"kv_connector":"RandomDropConnector","kv_role":"kv_both","kv_connector_module_path":"random_drop_connector"}'
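The script adds its own directory to PYTHONPATH so that the kv_connector_module_path value, random_drop_connector, is importable inside the server process. A roughly equivalent offline setup is sketched below; it assumes vLLM's KVTransferConfig accepts the same fields as the CLI JSON above, so treat it as illustrative rather than authoritative.

    # A minimal sketch, assuming KVTransferConfig mirrors the CLI JSON above.
    from vllm import LLM
    from vllm.config import KVTransferConfig

    llm = LLM(
        model="DeepSeek-V2-Lite-Chat",
        trust_remote_code=True,
        tensor_parallel_size=2,
        enable_prefix_caching=False,
        kv_transfer_config=KVTransferConfig(
            kv_connector="RandomDropConnector",
            kv_role="kv_both",
            # Resolved via PYTHONPATH, as in the shell script above.
            kv_connector_module_path="random_drop_connector",
        ),
    )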
@@ -139,27 +139,13 @@ class KVOutputAggregator:
                     finished_set.add(req_id)
                     del remaining_count_dict[req_id]
 
-        def update_finished_load_dict(worker_finished_loading_dict: dict[str,
-                                                                         int],
-                                      finished_loading_dict: dict[str, int]):
-            for req_id, num_actual_load_tokens in (worker_finished_loading_dict
-                                                   or {}).items():
-                if req_id in finished_loading_dict:
-                    finished_loading_dict[req_id] = min(
-                        finished_loading_dict[req_id], num_actual_load_tokens)
-                else:
-                    finished_loading_dict[req_id] = num_actual_load_tokens
-
         finished_sending = set[str]()
         finished_recving = set[str]()
-        finished_loading_dict: dict[str, int] = {}
         for output in outputs:
             update_finished_set(output.finished_sending,
                                 self._send_remaining_count, finished_sending)
             update_finished_set(output.finished_recving,
                                 self._recv_remaining_count, finished_recving)
-            update_finished_load_dict(output.finished_loading_dict,
-                                      finished_loading_dict)
 
         # select output of the worker specified by output_rank
         output = outputs[output_rank]
@@ -171,7 +157,7 @@ class KVOutputAggregator:
         # send/recv
         output.finished_sending = finished_sending if finished_sending else None
         output.finished_recving = finished_recving if finished_recving else None
-        output.finished_loading_dict = finished_loading_dict or None
 
         return output
 
     def async_aggregate(self,
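The deleted update_finished_load_dict helper above reduces per-worker results with min(): under tensor parallelism a token's KV is only usable once every rank has loaded its shard, so the slowest worker bounds the aggregate. A standalone rendering of that reduction (the function name here is illustrative, not vLLM API):

    def aggregate_loaded_tokens(
            per_worker: list[dict[str, int]]) -> dict[str, int]:
        # Keep the minimum count per request across workers: a token only
        # counts as loaded once every rank has its share of the KV cache.
        merged: dict[str, int] = {}
        for worker_dict in per_worker:
            for req_id, num_tokens in (worker_dict or {}).items():
                merged[req_id] = min(merged.get(req_id, num_tokens), num_tokens)
        return merged

    # With tp=2: rank 0 loaded 100 tokens of "req-0" but rank 1 only 80,
    # so the request can resume from at most 80 loaded tokens.
    print(aggregate_loaded_tokens([{"req-0": 100}, {"req-0": 80}]))  # {'req-0': 80}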
@@ -28,9 +28,6 @@ The class provides the following primitives:
 
     get_finished() - called with ids of finished requests, returns
         ids of requests that have completed async sending/recving.
-    get_finished_loading() - called with scheduler outputs, returns
-        a dictionary whose keys are request IDs and whose values are
-        the actual number of tokens loaded from the remote KV cache.
 """
 
 import enum
@@ -222,23 +219,6 @@ class KVConnectorBase_V1(ABC):
         """
         return None, None
 
-    def get_finished_loading(
-            self, scheduler_output: "SchedulerOutput") -> dict[str, int]:
-        """
-        Retrieve the actual number of tokens loaded for requests that have
-        completed the asynchronous loading process from the remote KV cache.
-
-        This method is used by the scheduler process (via the Executors)
-        to track the progress of requests and determine which requests have
-        successfully finished loading their KV cache data.
-
-        Returns:
-            A dictionary where the keys are request IDs and the values are the
-            corresponding number of tokens that have been successfully loaded
-            for each request.
-        """
-        return {}
-
     # ==============================
     # Scheduler-side methods
     # ==============================
@@ -1167,8 +1167,6 @@ class IntermediateTensors:
     # [req_ids]
     finished_sending: Optional[set[str]] = None
     finished_recving: Optional[set[str]] = None
-    # req_id -> num_actual_load_tokens
-    finished_loading_dict: Optional[dict[str, int]] = None
 
     def __init__(self, tensors):
         # manually define this function, so that
@@ -118,9 +118,6 @@ class Scheduler(SchedulerInterface):
 
         # KV Connector: requests in process of async KV loading or recving
         self.finished_recving_kv_req_ids: set[str] = set()
-        # The keys are request IDs; the values are the number of tokens that
-        # have been successfully loaded from the remote KV store.
-        self.finished_loading_dict: dict[str, int] = {}
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -1097,27 +1094,6 @@ class Scheduler(SchedulerInterface):
         (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
         return self.connector.request_finished(request, block_ids)
 
-    def _update_actual_load_token_num_from_remote_kv(self,
-                                                     request: Request) -> bool:
-
-        num_actual_load_tokens = self.finished_loading_dict.pop(
-            request.request_id)
-        num_computed_tokens = num_actual_load_tokens
-        assert self.connector is not None
-        if num_actual_load_tokens <= 0 and hasattr(self.connector,
-                                                   "add_failure_request"):
-            self.connector.add_failure_request(request)
-            return True
-
-        if num_actual_load_tokens == request.num_tokens:
-            num_computed_tokens -= 1
-
-        self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
-
-        # Update the request state for scheduling.
-        request.num_computed_tokens = num_computed_tokens
-        return True
-
     def _update_waiting_for_remote_kv(self, request: Request) -> bool:
         """
         KV Connector: check if the request_id is finished_recving.
@@ -1131,9 +1107,6 @@ class Scheduler(SchedulerInterface):
             WAITING_FOR_REMOTE_KV.
         """
         assert self.connector is not None
-        if request.request_id in self.finished_loading_dict:
-            return self._update_actual_load_token_num_from_remote_kv(request)
-
         if request.request_id not in self.finished_recving_kv_req_ids:
             return False
 
@@ -1172,6 +1145,3 @@ class Scheduler(SchedulerInterface):
         for req_id in (model_runner_output.finished_sending or ()):
             logger.debug("Finished sending KV transfer for request %s", req_id)
             self._free_blocks(self.requests[req_id])
-        if model_runner_output.finished_loading_dict:
-            self.finished_loading_dict.update(
-                model_runner_output.finished_loading_dict)
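The deleted _update_actual_load_token_num_from_remote_kv above covers three cases: a load of zero (or fewer) tokens is treated as a failure and, when the connector supports it, reported back via add_failure_request so the request is recomputed; a full load of all request.num_tokens backs off by one so the final token is recomputed and still produces logits; a partial load simply resumes from the loaded prefix. A standalone sketch of that decision (the helper name is illustrative, not vLLM API):

    from typing import Optional

    def resolve_num_computed_tokens(num_actual_load_tokens: int,
                                    num_tokens: int) -> Optional[int]:
        # None signals a failed load: reschedule and recompute from scratch.
        if num_actual_load_tokens <= 0:
            return None
        # A full hit still recomputes the last token, since the model must
        # run on at least one token to produce the next logits.
        if num_actual_load_tokens == num_tokens:
            return num_tokens - 1
        return num_actual_load_tokens

    assert resolve_num_computed_tokens(0, 100) is None   # dropped -> retry
    assert resolve_num_computed_tokens(100, 100) == 99   # full hit
    assert resolve_num_computed_tokens(60, 100) == 60    # partial prefix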
@@ -107,8 +107,6 @@ class ModelRunnerOutput:
     # [req_ids]
     finished_sending: Optional[set[str]] = None
     finished_recving: Optional[set[str]] = None
-    # req_id -> actual_load_token from connector
-    finished_loading_dict: Optional[dict[str, int]] = None
 
     # req_id -> num_nans_in_logits
     num_nans_in_logits: Optional[dict[str, int]] = None
@@ -123,5 +121,4 @@ EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               pooler_output=[],
                                               finished_sending=None,
                                               finished_recving=None,
-                                              finished_loading_dict=None,
                                               num_nans_in_logits=None)
@@ -1375,7 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_scheduled_tokens_np: np.ndarray,
         finished_sending: Optional[set[str]],
         finished_recving: Optional[set[str]],
-        finished_loading_dict: Optional[dict[str, int]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1412,7 +1411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pooler_output=pooler_output,
             finished_sending=finished_sending,
             finished_recving=finished_recving,
-            finished_loading_dict=finished_loading_dict,
         )
 
     @torch.inference_mode()
@@ -1532,7 +1530,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.maybe_wait_for_kv_save()
             finished_sending, finished_recving = (
                 self.get_finished_kv_transfers(scheduler_output))
-            finished_loading_dict = self.get_finished_loading(scheduler_output)
 
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
@@ -1550,11 +1547,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if not get_pp_group().is_last_rank:
             # For mid-pipeline stages, return the hidden states.
             if not broadcast_pp_output:
-                if (finished_sending or finished_recving
-                        or finished_loading_dict):
+                if finished_sending or finished_recving:
                     hidden_states.finished_sending = finished_sending
                     hidden_states.finished_recving = finished_recving
-                    hidden_states.finished_loading_dict = finished_loading_dict
                 return hidden_states
             assert isinstance(hidden_states, IntermediateTensors)
             get_pp_group().send_tensor_dict(hidden_states.tensors,
@@ -1564,7 +1559,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.input_batch.pooling_params:
             return self._pool(hidden_states, num_scheduled_tokens,
                               num_scheduled_tokens_np, finished_sending,
-                              finished_recving, finished_loading_dict)
+                              finished_recving)
 
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1716,7 +1711,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pooler_output=[],
             finished_sending=finished_sending,
             finished_recving=finished_recving,
-            finished_loading_dict=finished_loading_dict,
             num_nans_in_logits=num_nans_in_logits,
         )
 
@@ -359,12 +359,10 @@ class Worker(WorkerBase):
             # In case of PP with kv transfer, we need to pass through the
             # finished_sending and finished_recving buffers.
             new_output = EMPTY_MODEL_RUNNER_OUTPUT
-            if (output.finished_sending or output.finished_recving
-                    or output.finished_loading_dict):
+            if output.finished_sending or output.finished_recving:
                 new_output = copy.copy(new_output)
                 new_output.finished_sending = output.finished_sending
                 new_output.finished_recving = output.finished_recving
-                new_output.finished_loading_dict = output.finished_loading_dict
             output = new_output
 
         assert isinstance(output, ModelRunnerOutput)
@@ -53,14 +53,6 @@ class KVConnectorModelRunnerMixin:
                                       scheduler_output.finished_req_ids)
         return None, None
 
-    @staticmethod
-    def get_finished_loading(
-            scheduler_output: "SchedulerOutput") -> dict[str, int]:
-        if has_kv_transfer_group():
-            return get_kv_transfer_group().get_finished_loading(
-                scheduler_output)
-        return {}
-
    def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
                                vllm_config: VllmConfig) -> ModelRunnerOutput:
        # KV send/recv even if no work to do.
@@ -68,14 +60,11 @@ class KVConnectorModelRunnerMixin:
             self.maybe_setup_kv_connector(scheduler_output)
             finished_sending, finished_recving = (
                 self.get_finished_kv_transfers(scheduler_output))
-            finished_loading_dict = self.get_finished_loading(scheduler_output)
 
-        if (not finished_sending and not finished_recving
-                and not finished_loading_dict):
+        if not finished_sending and not finished_recving:
             return EMPTY_MODEL_RUNNER_OUTPUT
 
         output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
         output.finished_sending = finished_sending
         output.finished_recving = finished_recving
-        output.finished_loading_dict = finished_loading_dict
         return output
 